From 6585f49841ada637b0811e0aadcf93132fff7001 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 21 Aug 2013 14:51:56 -0700 Subject: Update build docs --- README.md | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index 1dd96a0a4a..1e388ff380 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Spark requires Scala 2.9.3 (Scala 2.10 is not yet supported). The project is built using Simple Build Tool (SBT), which is packaged with it. To build Spark and its example programs, run: - sbt/sbt package + sbt/sbt package assembly Spark also supports building using Maven. If you would like to build using Maven, see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html) @@ -43,10 +43,48 @@ locally with one thread, or "local[N]" to run locally with N threads. ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported -storage systems. Because the HDFS API has changed in different versions of +storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. -You can change the version by setting the `HADOOP_VERSION` variable at the top -of `project/SparkBuild.scala`, then rebuilding Spark. +You can change the version by setting the `SPARK_HADOOP_VERSION` environment +when building Spark. + +For Apache Hadoop versions 1.x, 0.20.x, Cloudera CDH MRv1, and other Hadoop +versions without YARN, use: + + # Apache Hadoop 1.2.1 + $ SPARK_HADOOP_VERSION=1.2.1 sbt/sbt package assembly + + # Cloudera CDH 4.2.0 with MapReduce v1 + $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt package assembly + +For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions +with YARN, also set `SPARK_WITH_YARN=true`: + + # Apache Hadoop 2.0.5-alpha + $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true sbt/sbt package assembly + + # Cloudera CDH 4.2.0 with MapReduce v2 + $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_WITH_YARN=true sbt/sbt package assembly + +For convenience, these variables may also be set through the `conf/spark-env.sh` file +described below. + +When developing a Spark application, specify the Hadoop version by adding the +"hadoop-client" artifact to your project's dependencies. For example, if you're +using Hadoop 0.23.9 and build your application using SBT, add this to +`libraryDependencies`: + + // "force()" is required because "0.23.9" is less than Spark's default of "1.0.4" + "org.apache.hadoop" % "hadoop-client" % "0.23.9" force() + +If your project is built with Maven, add this to your POM file's `` section: + + + org.apache.hadoop + hadoop-client + + [0.23.9] + ## Configuration -- cgit v1.2.3 From 4d737b6d320993432f286e05b773135fc3334ce3 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 21 Aug 2013 14:56:25 -0700 Subject: Example should make sense --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index 1e388ff380..9b2c1a9c4d 100644 --- a/README.md +++ b/README.md @@ -71,19 +71,19 @@ described below. When developing a Spark application, specify the Hadoop version by adding the "hadoop-client" artifact to your project's dependencies. For example, if you're -using Hadoop 0.23.9 and build your application using SBT, add this to +using Hadoop 1.0.1 and build your application using SBT, add this to `libraryDependencies`: - // "force()" is required because "0.23.9" is less than Spark's default of "1.0.4" - "org.apache.hadoop" % "hadoop-client" % "0.23.9" force() + // "force()" is required because "1.0.1" is less than Spark's default of "1.0.4" + "org.apache.hadoop" % "hadoop-client" % "1.0.1" force() If your project is built with Maven, add this to your POM file's `` section: org.apache.hadoop hadoop-client - - [0.23.9] + + [1.0.1] -- cgit v1.2.3 From f9cc1fbf272606ce0b679936ef38429b20916cc1 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 21 Aug 2013 17:12:03 -0700 Subject: Remove references to unsupported Hadoop versions --- README.md | 2 +- docs/building-with-maven.md | 2 +- project/SparkBuild.scala | 7 +++---- 3 files changed, 5 insertions(+), 6 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index 9b2c1a9c4d..8502410c52 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Hadoop, you must build Spark against the same version that your cluster runs. You can change the version by setting the `SPARK_HADOOP_VERSION` environment when building Spark. -For Apache Hadoop versions 1.x, 0.20.x, Cloudera CDH MRv1, and other Hadoop +For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: # Apache Hadoop 1.2.1 diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index d71d94fa63..a9f2cb8a7a 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -12,7 +12,7 @@ Building Spark using Maven Requires Maven 3 (the build process is tested with Ma To enable support for HDFS and other Hadoop-supported storage systems, specify the exact Hadoop version by setting the "hadoop.version" property. If unset, Spark will build against Hadoop 1.0.4 by default. -For Apache Hadoop versions 1.x, 0.20.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: +For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: # Apache Hadoop 1.2.1 $ mvn -Dhadoop.version=1.2.1 clean install diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 81080741ca..cc09bf7dd2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -24,10 +24,9 @@ import AssemblyKeys._ //import com.jsuereth.pgp.sbtplugin.PgpKeys._ object SparkBuild extends Build { - // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or - // "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - // Note that these variables can be set through the environment variables - // SPARK_HADOOP_VERSION and SPARK_WITH_YARN. + // Hadoop version to build against. For example, "1.0.4" for Apache releases, or + // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set + // through the environment variables SPARK_HADOOP_VERSION and SPARK_WITH_YARN. val DEFAULT_HADOOP_VERSION = "1.0.4" val DEFAULT_WITH_YARN = false -- cgit v1.2.3 From 0087b43e9cddc726f661e1e047e63390d5d9b419 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 21 Aug 2013 21:15:00 -0700 Subject: Use Hadoop 1.2.1 in application example --- README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index 8502410c52..e5f527b84a 100644 --- a/README.md +++ b/README.md @@ -71,19 +71,18 @@ described below. When developing a Spark application, specify the Hadoop version by adding the "hadoop-client" artifact to your project's dependencies. For example, if you're -using Hadoop 1.0.1 and build your application using SBT, add this to +using Hadoop 1.0.1 and build your application using SBT, add this entry to `libraryDependencies`: - // "force()" is required because "1.0.1" is less than Spark's default of "1.0.4" - "org.apache.hadoop" % "hadoop-client" % "1.0.1" force() + "org.apache.hadoop" % "hadoop-client" % "1.2.1" If your project is built with Maven, add this to your POM file's `` section: org.apache.hadoop hadoop-client - - [1.0.1] + + [1.2.1] -- cgit v1.2.3 From 53cd50c0699efc8733518658100c62426b425de2 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 23 Aug 2013 23:30:17 -0700 Subject: Change build and run instructions to use assemblies This commit makes Spark invocation saner by using an assembly JAR to find all of Spark's dependencies instead of adding all the JARs in lib_managed. It also packages the examples into an assembly and uses that as SPARK_EXAMPLES_JAR. Finally, it replaces the old "run" script with two better-named scripts: "run-examples" for examples, and "spark-class" for Spark internal classes (e.g. REPL, master, etc). This is also designed to minimize the confusion people have in trying to use "run" to run their own classes; it's not meant to do that, but now at least if they look at it, they can modify run-examples to do a decent job for them. As part of this, Bagel's examples are also now properly moved to the examples package instead of bagel. --- README.md | 6 +- .../scala/spark/bagel/examples/PageRankUtils.scala | 123 ------------ .../spark/bagel/examples/WikipediaPageRank.scala | 101 ---------- .../examples/WikipediaPageRankStandalone.scala | 223 --------------------- bin/compute-classpath.sh | 82 ++------ bin/spark-daemon.sh | 8 +- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- core/src/test/scala/spark/DriverSuite.scala | 2 +- docs/bagel-programming-guide.md | 2 +- docs/index.md | 4 +- docs/java-programming-guide.md | 2 +- docs/running-on-yarn.md | 8 +- docs/scala-programming-guide.md | 2 +- docs/spark-standalone.md | 4 +- docs/streaming-programming-guide.md | 4 +- examples/pom.xml | 5 + .../scala/spark/examples/bagel/PageRankUtils.scala | 123 ++++++++++++ .../spark/examples/bagel/WikipediaPageRank.scala | 101 ++++++++++ .../bagel/WikipediaPageRankStandalone.scala | 223 +++++++++++++++++++++ make-distribution.sh | 17 +- project/SparkBuild.scala | 52 +++-- project/build.properties | 2 +- project/plugins.sbt | 2 +- python/pyspark/java_gateway.py | 2 +- run | 175 ---------------- run-example | 76 +++++++ spark-class | 124 ++++++++++++ spark-executor | 2 +- spark-shell | 3 +- 29 files changed, 734 insertions(+), 746 deletions(-) delete mode 100644 bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala delete mode 100644 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala delete mode 100644 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala create mode 100644 examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala create mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala create mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala delete mode 100755 run create mode 100755 run-example create mode 100755 spark-class (limited to 'README.md') diff --git a/README.md b/README.md index e5f527b84a..e7af34f513 100644 --- a/README.md +++ b/README.md @@ -20,16 +20,16 @@ Spark and its example programs, run: Spark also supports building using Maven. If you would like to build using Maven, see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html) -in the spark documentation.. +in the Spark documentation.. To run Spark, you will need to have Scala's bin directory in your `PATH`, or you will need to set the `SCALA_HOME` environment variable to point to where you've installed Scala. Scala must be accessible through one of these methods on your cluster's worker nodes as well as its master. -To run one of the examples, use `./run `. For example: +To run one of the examples, use `./run-example `. For example: - ./run spark.examples.SparkLR local[2] + ./run-example spark.examples.SparkLR local[2] will run the Logistic Regression example locally on 2 CPUs. diff --git a/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala b/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala deleted file mode 100644 index de65e27fe0..0000000000 --- a/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel.examples - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} - -import com.esotericsoftware.kryo._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) - self.outEdges.map(targetId => - new PRMessage(targetId, newValue / self.outEdges.size)) - else - Array[PRMessage]() - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } -} diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala deleted file mode 100644 index a0c5ac9c18..0000000000 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel.examples - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRank ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRank") - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache - else - vertices = vertices.cache - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect.mkString) - println(top) - } -} diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala deleted file mode 100644 index 3c54a85f42..0000000000 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel.examples - -import spark._ -import serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRankStandalone ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRankStandalone") - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) - input.map(parseArticle _).partitionBy(partitioner).cache() - else - input.map(parseArticle _).cache() - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - System.exit(0) - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapper, rankWrapper)) => - if (linksWrapper.length > 0) { - if (rankWrapper.length > 0) { - linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) - } else { - linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 7a21b3c4a1..5dc86c51a4 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -30,79 +30,25 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh fi -CORE_DIR="$FWDIR/core" -REPL_DIR="$FWDIR/repl" -REPL_BIN_DIR="$FWDIR/repl-bin" -EXAMPLES_DIR="$FWDIR/examples" -BAGEL_DIR="$FWDIR/bagel" -MLLIB_DIR="$FWDIR/mllib" -TOOLS_DIR="$FWDIR/tools" -YARN_DIR="$FWDIR/yarn" -STREAMING_DIR="$FWDIR/streaming" -PYSPARK_DIR="$FWDIR/python" - # Build up classpath -CLASSPATH="$SPARK_CLASSPATH" - -function dev_classpath { - CLASSPATH="$CLASSPATH:$FWDIR/conf" - CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes" - if [ -n "$SPARK_TESTING" ] ; then - CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" - fi - CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources" - CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar - if [ -e "$FWDIR/lib_managed" ]; then - CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" - fi - CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" - # Add the shaded JAR for Maven builds - if [ -e $REPL_BIN_DIR/target ]; then - for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded.jar'`; do - CLASSPATH="$CLASSPATH:$jar" - done - # The shaded JAR doesn't contain examples, so include those separately - for jar in `find "$EXAMPLES_DIR/target" -name 'spark-examples*[0-9T].jar'`; do - CLASSPATH="$CLASSPATH:$jar" - done - fi - CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$MLLIB_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$TOOLS_DIR/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$YARN_DIR/target/scala-$SCALA_VERSION/classes" - for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do - CLASSPATH="$CLASSPATH:$jar" - done - - # Add Scala standard library - if [ -z "$SCALA_LIBRARY_PATH" ]; then - if [ -z "$SCALA_HOME" ]; then - echo "SCALA_HOME is not set" >&2 - exit 1 - fi - SCALA_LIBRARY_PATH="$SCALA_HOME/lib" - fi - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar" - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar" - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar" -} - -function release_classpath { - CLASSPATH="$CLASSPATH:$FWDIR/jars/*" -} - +CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf" if [ -f "$FWDIR/RELEASE" ]; then - release_classpath + ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar` else - dev_classpath + ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*.jar` +fi +CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" + +# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 +if [[ $SPARK_TESTING == 1 ]]; then + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" fi -# Add hadoop conf dir - else FileSystem.*, etc fail ! +# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! # Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts # the configurtion files. if [ "x" != "x$HADOOP_CONF_DIR" ]; then diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh index 96c71e66ca..eac0774669 100755 --- a/bin/spark-daemon.sh +++ b/bin/spark-daemon.sh @@ -87,7 +87,7 @@ TEST_LOG_DIR=$? if [ "${TEST_LOG_DIR}" = "0" ]; then rm -f $SPARK_LOG_DIR/.spark_test else - chown $SPARK_IDENT_STRING $SPARK_LOG_DIR + chown $SPARK_IDENT_STRING $SPARK_LOG_DIR fi if [ "$SPARK_PID_DIR" = "" ]; then @@ -109,7 +109,7 @@ fi case $startStop in (start) - + mkdir -p "$SPARK_PID_DIR" if [ -f $pid ]; then @@ -128,11 +128,11 @@ case $startStop in echo starting $command, logging to $log echo "Spark Daemon: $command" > $log cd "$SPARK_PREFIX" - nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/run $command "$@" >> "$log" 2>&1 < /dev/null & + nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & echo $! > $pid sleep 1; head "$log" ;; - + (stop) if [ -f $pid ]; then diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 6ebbb5ec9b..ebfc21392d 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -125,7 +125,7 @@ private[spark] class CoarseMesosSchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val uri = System.getProperty("spark.executor.uri") if (uri == null) { - val runScript = new File(sparkHome, "run").getCanonicalPath + val runScript = new File(sparkHome, "spark-class").getCanonicalPath command.setValue("\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) } else { diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index ed16b9d8ef..553c0309f6 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts { val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => failAfter(30 seconds) { - Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), + Utils.execute(Seq("./spark-class", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) } } diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 8a0fa42d94..c526da3ca0 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -158,4 +158,4 @@ trait Message[K] { ## Where to Go from Here -Two example jobs, PageRank and shortest path, are included in `bagel/src/main/scala/spark/bagel/examples`. You can run them by passing the class name to the `run` script included in Spark -- for example, `./run spark.bagel.examples.WikipediaPageRank`. Each example program prints usage help when run without any arguments. +Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/spark/examples/bagel`. You can run them by passing the class name to the `run-example` script included in Spark -- for example, `./run-example spark.examples.bagel.WikipediaPageRank`. Each example program prints usage help when run without any arguments. diff --git a/docs/index.md b/docs/index.md index 0c4add45dc..e51a6998f6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,9 +27,9 @@ Spark also supports building using Maven. If you would like to build using Maven # Testing the Build Spark comes with a number of sample programs in the `examples` directory. -To run one of the samples, use `./run ` in the top-level Spark directory +To run one of the samples, use `./run-example ` in the top-level Spark directory (the `run` script sets up the appropriate paths and launches that program). -For example, `./run spark.examples.SparkPi` will run a sample program that estimates Pi. Each of the +For example, `./run-example spark.examples.SparkPi` will run a sample program that estimates Pi. Each of the examples prints usage help if no params are given. Note that all of the sample programs take a `` parameter specifying the cluster URL diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index ae8257b539..dd19a5f0c9 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -190,6 +190,6 @@ We hope to generate documentation with Java-style syntax in the future. Spark includes several sample programs using the Java API in [`examples/src/main/java`](https://github.com/mesos/spark/tree/master/examples/src/main/java/spark/examples). You can run them by passing the class name to the -`run` script included in Spark -- for example, `./run +`run-example` script included in Spark -- for example, `./run-example spark.examples.JavaWordCount`. Each example program prints usage help when run without any arguments. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 1a0afd19d4..678cd57aba 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -15,9 +15,9 @@ We need a consolidated spark core jar (which bundles all the required dependenci This can be built either through sbt or via maven. - Building spark assembled jar via sbt. -Enable YARN support by setting `SPARK_WITH_YARN=true` when invoking sbt: +Enable YARN support by setting `SPARK_YARN=true` when invoking sbt: - SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true ./sbt/sbt clean assembly + SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true ./sbt/sbt clean assembly The assembled jar would typically be something like : `./yarn/target/spark-yarn-assembly-0.8.0-SNAPSHOT.jar` @@ -55,7 +55,7 @@ This would be used to connect to the cluster, write to the dfs and submit jobs t The command to launch the YARN Client is as follows: - SPARK_JAR= ./run spark.deploy.yarn.Client \ + SPARK_JAR= ./spark-class spark.deploy.yarn.Client \ --jar \ --class \ --args \ @@ -67,7 +67,7 @@ The command to launch the YARN Client is as follows: For example: - SPARK_JAR=./yarn/target/spark-yarn-assembly-{{site.SPARK_VERSION}}.jar ./run spark.deploy.yarn.Client \ + SPARK_JAR=./yarn/target/spark-yarn-assembly-{{site.SPARK_VERSION}}.jar ./spark-class spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_VERSION}}/spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}.jar \ --class spark.examples.SparkPi \ --args yarn-standalone \ diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index e9cf9ef36f..db584d2096 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -356,7 +356,7 @@ res2: Int = 10 # Where to Go from Here You can see some [example Spark programs](http://www.spark-project.org/examples.html) on the Spark website. -In addition, Spark includes several sample programs in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them using by passing the class name to the `run` script included in Spark -- for example, `./run spark.examples.SparkPi`. Each example program prints usage help when run without any arguments. +In addition, Spark includes several sample programs in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them using by passing the class name to the `run-example` script included in Spark -- for example, `./run-example spark.examples.SparkPi`. Each example program prints usage help when run without any arguments. For help on optimizing your program, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 7463844a4e..bb8be276c5 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -20,7 +20,7 @@ Compile Spark with `sbt package` as described in the [Getting Started Guide](ind You can start a standalone master server by executing: - ./run spark.deploy.master.Master + ./spark-class spark.deploy.master.Master Once started, the master will print out a `spark://IP:PORT` URL for itself, which you can use to connect workers to it, or pass as the "master" argument to `SparkContext` to connect a job to the cluster. You can also find this URL on @@ -28,7 +28,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./run spark.deploy.worker.Worker spark://IP:PORT + ./spark-class spark.deploy.worker.Worker spark://IP:PORT Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index a74c17bdb7..3330e63598 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -234,7 +234,7 @@ $ nc -lk 9999 Then, in a different terminal, you can start NetworkWordCount by using {% highlight bash %} -$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999 +$ ./run-example spark.streaming.examples.NetworkWordCount local[2] localhost 9999 {% endhighlight %} This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. @@ -272,7 +272,7 @@ Time: 1357008430000 ms -You can find more examples in `/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. +You can find more examples in `/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run-example spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. # DStream Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. diff --git a/examples/pom.xml b/examples/pom.xml index 0db52b8691..d24bd404fa 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -47,6 +47,11 @@ spark-mllib ${project.version} + + org.spark-project + spark-bagel + ${project.version} + org.apache.hbase hbase diff --git a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala new file mode 100644 index 0000000000..c23ee9895f --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +import com.esotericsoftware.kryo._ + +class PageRankUtils extends Serializable { + def computeWithCombiner(numVertices: Long, epsilon: Double)( + self: PRVertex, messageSum: Option[Double], superstep: Int + ): (PRVertex, Array[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = superstep >= 10 + + val outbox: Array[PRMessage] = + if (!terminate) + self.outEdges.map(targetId => + new PRMessage(targetId, newValue / self.outEdges.size)) + else + Array[PRMessage]() + + (new PRVertex(newValue, self.outEdges, !terminate), outbox) + } + + def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = + computeWithCombiner(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +class PRCombiner extends Combiner[PRMessage, Double] with Serializable { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b +} + +class PRVertex() extends Vertex with Serializable { + var value: Double = _ + var outEdges: Array[String] = _ + var active: Boolean = _ + + def this(value: Double, outEdges: Array[String], active: Boolean = true) { + this() + this.value = value + this.outEdges = outEdges + this.active = active + } + + override def toString(): String = { + "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) + } +} + +class PRMessage() extends Message[String] with Serializable { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + } +} + +class CustomPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val hash = key match { + case k: Long => (k & 0x00000000FFFFFFFFL).toInt + case _ => key.hashCode + } + + val mod = key.hashCode % partitions + if (mod < 0) mod + partitions else mod + } + + override def equals(other: Any): Boolean = other match { + case c: CustomPartitioner => + c.numPartitions == numPartitions + case _ => false + } +} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala new file mode 100644 index 0000000000..00635a7ffa --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +/** + * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" + * files from there, which contains one line per wiki article in a tab-separated format + * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). + */ +object WikipediaPageRank { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRank ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + + val inputFile = args(0) + val threshold = args(1).toDouble + val numPartitions = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRank") + + // Parse the Wikipedia page data into a graph + val input = sc.textFile(inputFile) + + println("Counting vertices...") + val numVertices = input.count() + println("Done counting vertices.") + + println("Parsing input file...") + var vertices = input.map(line => { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + val id = new String(title) + (id, new PRVertex(1.0 / numVertices, outEdges)) + }) + if (usePartitioner) + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + else + vertices = vertices.cache + println("Done parsing input file.") + + // Do the computation + val epsilon = 0.01 / numVertices + val messages = sc.parallelize(Array[(String, PRMessage)]()) + val utils = new PageRankUtils + val result = + Bagel.run( + sc, vertices, messages, combiner = new PRCombiner(), + numPartitions = numPartitions)( + utils.computeWithCombiner(numVertices, epsilon)) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (result + .filter { case (id, vertex) => vertex.value >= threshold } + .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } + .collect.mkString) + println(top) + } +} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala new file mode 100644 index 0000000000..c416ddbc58 --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer + +object WikipediaPageRankStandalone { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRankStandalone ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") + + val inputFile = args(0) + val threshold = args(1).toDouble + val numIterations = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRankStandalone") + + val input = sc.textFile(inputFile) + val partitioner = new HashPartitioner(sc.defaultParallelism) + val links = + if (usePartitioner) + input.map(parseArticle _).partitionBy(partitioner).cache() + else + input.map(parseArticle _).cache() + val n = links.count() + val defaultRank = 1.0 / n + val a = 0.15 + + // Do the computation + val startTime = System.currentTimeMillis + val ranks = + pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (ranks + .filter { case (id, rank) => rank >= threshold } + .map { case (id, rank) => "%s\t%s\n".format(id, rank) } + .collect().mkString) + println(top) + + val time = (System.currentTimeMillis - startTime) / 1000.0 + println("Completed %d iterations in %f seconds: %f seconds per iteration" + .format(numIterations, time, time / numIterations)) + System.exit(0) + } + + def parseArticle(line: String): (String, Array[String]) = { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val id = new String(title) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + (id, outEdges) + } + + def pageRank( + links: RDD[(String, Array[String])], + numIterations: Int, + defaultRank: Double, + a: Double, + n: Long, + partitioner: Partitioner, + usePartitioner: Boolean, + numPartitions: Int + ): RDD[(String, Double)] = { + var ranks = links.mapValues { edges => defaultRank } + for (i <- 1 to numIterations) { + val contribs = links.groupWith(ranks).flatMap { + case (id, (linksWrapper, rankWrapper)) => + if (linksWrapper.length > 0) { + if (rankWrapper.length > 0) { + linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + } else { + linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + } + } else { + Array[(String, Double)]() + } + } + ranks = (contribs.combineByKey((x: Double) => x, + (x: Double, y: Double) => x + y, + (x: Double, y: Double) => x + y, + partitioner) + .mapValues(sum => a/n + (1-a)*sum)) + } + ranks + } +} + +class WPRSerializer extends spark.serializer.Serializer { + def newInstance(): SerializerInstance = new WPRSerializerInstance() +} + +class WPRSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer): T = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + + def serializeStream(s: OutputStream): SerializationStream = { + new WPRSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new WPRDeserializationStream(s) + } +} + +class WPRSerializationStream(os: OutputStream) extends SerializationStream { + val dos = new DataOutputStream(os) + + def writeObject[T](t: T): SerializationStream = t match { + case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { + case links: Array[String] => { + dos.writeInt(0) // links + dos.writeUTF(id) + dos.writeInt(links.length) + for (link <- links) { + dos.writeUTF(link) + } + this + } + case rank: Double => { + dos.writeInt(1) // rank + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + case (id: String, rank: Double) => { + dos.writeInt(2) // rank without wrapper + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + + def flush() { dos.flush() } + def close() { dos.close() } +} + +class WPRDeserializationStream(is: InputStream) extends DeserializationStream { + val dis = new DataInputStream(is) + + def readObject[T](): T = { + val typeId = dis.readInt() + typeId match { + case 0 => { + val id = dis.readUTF() + val numLinks = dis.readInt() + val links = new Array[String](numLinks) + for (i <- 0 until numLinks) { + val link = dis.readUTF() + links(i) = link + } + (id, ArrayBuffer(links)).asInstanceOf[T] + } + case 1 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, ArrayBuffer(rank)).asInstanceOf[T] + } + case 2 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, rank).asInstanceOf[T] + } + } + } + + def close() { dis.close() } +} diff --git a/make-distribution.sh b/make-distribution.sh index 70aff418c7..df7bbf1e74 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -47,7 +47,7 @@ VERSION=$($FWDIR/sbt/sbt "show version" | tail -1 | cut -f 2 | sed 's/^\([a-zA-Z # Initialize defaults SPARK_HADOOP_VERSION=1.0.4 -SPARK_WITH_YARN=false +SPARK_YARN=false MAKE_TGZ=false # Parse arguments @@ -58,7 +58,7 @@ while (( "$#" )); do shift ;; --with-yarn) - SPARK_WITH_YARN=true + SPARK_YARN=true ;; --tgz) MAKE_TGZ=true @@ -74,7 +74,7 @@ else fi echo "Hadoop version set to $SPARK_HADOOP_VERSION" -if [ "$SPARK_WITH_YARN" == "true" ]; then +if [ "$SPARK_YARN" == "true" ]; then echo "YARN enabled" else echo "YARN disabled" @@ -82,21 +82,22 @@ fi # Build fat JAR export SPARK_HADOOP_VERSION -export SPARK_WITH_YARN -"$FWDIR/sbt/sbt" "repl/assembly" +export SPARK_YARN +"$FWDIR/sbt/sbt" "assembly/assembly" # Make directories rm -rf "$DISTDIR" mkdir -p "$DISTDIR/jars" -echo "$VERSION" > "$DISTDIR/RELEASE" +echo "Spark $VERSION built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" # Copy jars -cp $FWDIR/repl/target/*.jar "$DISTDIR/jars/" +cp $FWDIR/assembly/target/*/*assembly*.jar "$DISTDIR/jars/" # Copy other things cp -r "$FWDIR/bin" "$DISTDIR" cp -r "$FWDIR/conf" "$DISTDIR" -cp "$FWDIR/run" "$FWDIR/spark-shell" "$DISTDIR" +cp "$FWDIR/spark-class" "$DISTDIR" +cp "$FWDIR/spark-shell" "$DISTDIR" cp "$FWDIR/spark-executor" "$DISTDIR" diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5fdcf19b62..b3bf3ef89b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -26,30 +26,35 @@ import AssemblyKeys._ object SparkBuild extends Build { // Hadoop version to build against. For example, "1.0.4" for Apache releases, or // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set - // through the environment variables SPARK_HADOOP_VERSION and SPARK_WITH_YARN. + // through the environment variables SPARK_HADOOP_VERSION and SPARK_YARN. val DEFAULT_HADOOP_VERSION = "1.0.4" - val DEFAULT_WITH_YARN = false + val DEFAULT_YARN = false // HBase version; set as appropriate. val HBASE_VERSION = "0.94.6" - lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects:_*) + lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects: _*) lazy val core = Project("core", file("core"), settings = coreSettings) - lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn(core) dependsOn(bagel) dependsOn(mllib) dependsOn(maybeYarn:_*) + lazy val repl = Project("repl", file("repl"), settings = replSettings) + .dependsOn(core, bagel, mllib) dependsOn(maybeYarn: _*) - lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) dependsOn(mllib) + lazy val examples = Project("examples", file("examples"), settings = examplesSettings) + .dependsOn(core, mllib, bagel, streaming) - lazy val tools = Project("tools", file("tools"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) + lazy val tools = Project("tools", file("tools"), settings = toolsSettings) dependsOn(core) dependsOn(streaming) - lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core) + lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn(core) - lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core) + lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core) - lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn (core) + lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) - lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn (core) + lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn(core) + + lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) + .dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) // A configuration to set an alternative publishLocalConfiguration lazy val MavenCompile = config("m2r") extend(Compile) @@ -57,15 +62,16 @@ object SparkBuild extends Build { // Allows build configuration to be set through environment variables lazy val hadoopVersion = scala.util.Properties.envOrElse("SPARK_HADOOP_VERSION", DEFAULT_HADOOP_VERSION) - lazy val isYarnEnabled = scala.util.Properties.envOrNone("SPARK_WITH_YARN") match { - case None => DEFAULT_WITH_YARN + lazy val isYarnEnabled = scala.util.Properties.envOrNone("SPARK_YARN") match { + case None => DEFAULT_YARN case Some(v) => v.toBoolean } // Conditionally include the yarn sub-project lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]() lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]() - lazy val allProjects = Seq[ProjectReference](core, repl, examples, bagel, streaming, mllib, tools) ++ maybeYarnRef + lazy val allProjects = Seq[ProjectReference]( + core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef def sharedSettings = Defaults.defaultSettings ++ Seq( organization := "org.spark-project", @@ -100,8 +106,8 @@ object SparkBuild extends Build { http://spark-project.org/ - BSD License - https://github.com/mesos/spark/blob/master/LICENSE + Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html repo @@ -195,7 +201,7 @@ object SparkBuild extends Build { "com.twitter" % "chill_2.9.3" % "0.3.1", "com.twitter" % "chill-java" % "0.3.1" ) - ) ++ assemblySettings ++ extraAssemblySettings + ) def rootSettings = sharedSettings ++ Seq( publish := {} @@ -204,7 +210,7 @@ object SparkBuild extends Build { def replSettings = sharedSettings ++ Seq( name := "spark-repl", libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _) - ) ++ assemblySettings ++ extraAssemblySettings + ) def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", @@ -223,7 +229,7 @@ object SparkBuild extends Build { exclude("org.apache.cassandra.deps", "avro") excludeAll(excludeSnappy) ) - ) + ) ++ assemblySettings ++ extraAssemblySettings def toolsSettings = sharedSettings ++ Seq( name := "spark-tools" @@ -251,7 +257,7 @@ object SparkBuild extends Build { "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty) ) - ) ++ assemblySettings ++ extraAssemblySettings + ) def yarnSettings = sharedSettings ++ Seq( name := "spark-yarn" @@ -271,7 +277,13 @@ object SparkBuild extends Build { ) ) - def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( + def assemblyProjSettings = sharedSettings ++ Seq( + name := "spark-assembly", + jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } + ) ++ assemblySettings ++ extraAssemblySettings + + def extraAssemblySettings() = Seq( + test in assembly := {}, mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard diff --git a/project/build.properties b/project/build.properties index 08e17131f6..9647277162 100644 --- a/project/build.properties +++ b/project/build.properties @@ -15,4 +15,4 @@ # limitations under the License. # -sbt.version=0.12.3 +sbt.version=0.12.4 diff --git a/project/plugins.sbt b/project/plugins.sbt index 783b40d4f5..cfcd85082a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,7 +4,7 @@ resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/release resolvers += "Spray Repository" at "http://repo.spray.cc/" -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.8.5") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.1") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index e503fb7621..18011c0dc9 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -28,7 +28,7 @@ SPARK_HOME = os.environ["SPARK_HOME"] def launch_gateway(): # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and SPARK_MEM settings from spark-env.sh - command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", + command = [os.path.join(SPARK_HOME, "spark-class"), "py4j.GatewayServer", "--die-on-broken-pipe", "0"] proc = Popen(command, stdout=PIPE, stdin=PIPE) # Determine which ephemeral port the server started on: diff --git a/run b/run deleted file mode 100755 index 715bbf93d5..0000000000 --- a/run +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -SCALA_VERSION=2.9.3 - -# Figure out where the Scala framework is installed -FWDIR="$(cd `dirname $0`; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -# Load environment variables from conf/spark-env.sh, if it exists -if [ -e $FWDIR/conf/spark-env.sh ] ; then - . $FWDIR/conf/spark-env.sh -fi - -if [ -z "$1" ]; then - echo "Usage: run []" >&2 - exit 1 -fi - -# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable -# values for that; it doesn't need a lot -if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then - SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} - SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" - # Do not overwrite SPARK_JAVA_OPTS environment variable in this script - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default -else - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" -fi - - -# Add java opts for master, worker, executor. The opts maybe null -case "$1" in - 'spark.deploy.master.Master') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS" - ;; - 'spark.deploy.worker.Worker') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS" - ;; - 'spark.executor.StandaloneExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - ;; - 'spark.executor.MesosExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - ;; - 'spark.repl.Main') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS" - ;; -esac - -# Figure out whether to run our class with java or with the scala launcher. -# In most cases, we'd prefer to execute our process with java because scala -# creates a shell script as the parent of its Java process, which makes it -# hard to kill the child with stuff like Process.destroy(). However, for -# the Spark shell, the wrapper is necessary to properly reset the terminal -# when we exit, so we allow it to set a variable to launch with scala. -# We still fall back on java for the shell if this is a "release" created -# from make-distribution.sh since it's possible scala is not installed -# but we have everything we need to run the shell. -if [[ "$SPARK_LAUNCH_WITH_SCALA" == "1" && ! -f "$FWDIR/RELEASE" ]]; then - if [ "$SCALA_HOME" ]; then - RUNNER="${SCALA_HOME}/bin/scala" - else - if [ `command -v scala` ]; then - RUNNER="scala" - else - echo "SCALA_HOME is not set and scala is not in PATH" >&2 - exit 1 - fi - fi -else - if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" - else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi - fi - if [[ ! -f "$FWDIR/RELEASE" && -z "$SCALA_LIBRARY_PATH" ]]; then - if [ -z "$SCALA_HOME" ]; then - echo "SCALA_HOME is not set" >&2 - exit 1 - fi - SCALA_LIBRARY_PATH="$SCALA_HOME/lib" - fi -fi - -# Figure out how much memory to use per executor and set it as an environment -# variable so that our process sees it and can report it to Mesos -if [ -z "$SPARK_MEM" ] ; then - SPARK_MEM="512m" -fi -export SPARK_MEM - -# Set JAVA_OPTS to be able to load native libraries and to set heap size -JAVA_OPTS="$OUR_JAVA_OPTS" -JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH" -JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM" -# Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e $FWDIR/conf/java-opts ] ; then - JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" -fi -export JAVA_OPTS -# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! - -if [ ! -f "$FWDIR/RELEASE" ]; then - CORE_DIR="$FWDIR/core" - EXAMPLES_DIR="$FWDIR/examples" - REPL_DIR="$FWDIR/repl" - - # Exit if the user hasn't compiled Spark - if [ ! -e "$CORE_DIR/target" ]; then - echo "Failed to find Spark classes in $CORE_DIR/target" >&2 - echo "You need to compile Spark before running this program" >&2 - exit 1 - fi - - if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then - echo "Failed to find Spark classes in $REPL_DIR/target" >&2 - echo "You need to compile Spark repl module before running this program" >&2 - exit 1 - fi - - # Figure out the JAR file that our examples were packaged into. This includes a bit of a hack - # to avoid the -sources and -doc packages that are built by publish-local. - if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then - # Use the JAR from the SBT build - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar` - fi - if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then - # Use the JAR from the Maven build - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` - fi -fi - -# Compute classpath using external script -CLASSPATH=`$FWDIR/bin/compute-classpath.sh` -export CLASSPATH - -if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then - EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS -else - # The JVM doesn't read JAVA_OPTS by default so we need to pass it in - EXTRA_ARGS="$JAVA_OPTS" -fi - -command="$RUNNER -cp \"$CLASSPATH\" $EXTRA_ARGS $@" -if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then - echo "Spark Command: $command" - echo "========================================" - echo -fi - -exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@" diff --git a/run-example b/run-example new file mode 100755 index 0000000000..e1b26257e1 --- /dev/null +++ b/run-example @@ -0,0 +1,76 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCALA_VERSION=2.9.3 + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +if [ -z "$1" ]; then + echo "Usage: run-example []" >&2 + exit 1 +fi + +# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack +# to avoid the -sources and -doc packages that are built by publish-local. +EXAMPLES_DIR="$FWDIR"/examples +SPARK_EXAMPLES_JAR="" +if [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9T].jar ]; then + # Use the JAR from the SBT build + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9T].jar` +fi +if [ -e "$EXAMPLES_DIR"/target/spark-examples*[0-9T].jar ]; then + # Use the JAR from the Maven build + # TODO: this also needs to become an assembly! + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/spark-examples*[0-9T].jar` +fi +if [[ -z $SPARK_EXAMPLES_JAR ]]; then + echo "Failed to find Spark examples assembly in $FWDIR/examples/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + +# Find java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then + echo -n "Spark Command: " + echo "$RUNNER" -cp "$SPARK_EXAMPLES_JAR" "$@" + echo "========================================" + echo +fi + +exec "$RUNNER" -cp "$SPARK_EXAMPLES_JAR" "$@" diff --git a/spark-class b/spark-class new file mode 100755 index 0000000000..5ef3de9773 --- /dev/null +++ b/spark-class @@ -0,0 +1,124 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCALA_VERSION=2.9.3 + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +if [ -z "$1" ]; then + echo "Usage: run []" >&2 + exit 1 +fi + +# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable +# values for that; it doesn't need a lot +if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then + SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} + SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" + # Do not overwrite SPARK_JAVA_OPTS environment variable in this script + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default +else + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" +fi + + +# Add java opts for master, worker, executor. The opts maybe null +case "$1" in + 'spark.deploy.master.Master') + OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS" + ;; + 'spark.deploy.worker.Worker') + OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS" + ;; + 'spark.executor.StandaloneExecutorBackend') + OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + ;; + 'spark.executor.MesosExecutorBackend') + OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + ;; + 'spark.repl.Main') + OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS" + ;; +esac + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi +if [[ ! -f "$FWDIR/RELEASE" && -z "$SCALA_LIBRARY_PATH" ]]; then + if [ -z "$SCALA_HOME" ]; then + echo "SCALA_HOME is not set" >&2 + exit 1 + fi + SCALA_LIBRARY_PATH="$SCALA_HOME/lib" +fi + +# Set SPARK_MEM if it isn't already set since we also use it for this process +SPARK_MEM=${SPARK_MEM:-512m} +export SPARK_MEM + +# Set JAVA_OPTS to be able to load native libraries and to set heap size +JAVA_OPTS="$OUR_JAVA_OPTS" +JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH" +JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM" +# Load extra JAVA_OPTS from conf/java-opts, if it exists +if [ -e $FWDIR/conf/java-opts ] ; then + JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" +fi +export JAVA_OPTS +# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! + +if [ ! -f "$FWDIR/RELEASE" ]; then + # Exit if the user hasn't compiled Spark + ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*.jar >& /dev/null + if [[ $? != 0 ]]; then + echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +CLASSPATH=`$FWDIR/bin/compute-classpath.sh` +export CLASSPATH + +if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then + echo -n "Spark Command: " + echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" + echo "========================================" + echo +fi + +exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" diff --git a/spark-executor b/spark-executor index feccbf5cc2..63692bd46c 100755 --- a/spark-executor +++ b/spark-executor @@ -19,4 +19,4 @@ FWDIR="`dirname $0`" echo "Running spark-executor with framework dir = $FWDIR" -exec $FWDIR/run spark.executor.MesosExecutorBackend +exec $FWDIR/spark-class spark.executor.MesosExecutorBackend diff --git a/spark-shell b/spark-shell index 62fc18550d..4d379c5cfb 100755 --- a/spark-shell +++ b/spark-shell @@ -79,8 +79,7 @@ if [[ ! $? ]]; then saved_stty="" fi -export SPARK_LAUNCH_WITH_SCALA=${SPARK_LAUNCH_WITH_SCALA:-1} -$FWDIR/run $OPTIONS spark.repl.Main "$@" +$FWDIR/spark-class $OPTIONS spark.repl.Main "$@" # record the exit status lest it be overwritten: # then reenable echo and propagate the code. -- cgit v1.2.3 From 2de756ff195e580007d5d96d49fa27634e04c765 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 27 Aug 2013 19:39:54 -0700 Subject: Update some build instructions because only sbt assembly and mvn package are now needed --- README.md | 10 +++++----- docs/building-with-maven.md | 14 +++++++------- docs/index.md | 2 +- docs/python-programming-guide.md | 2 +- docs/quick-start.md | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index e7af34f513..89b5a0abfd 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Spark requires Scala 2.9.3 (Scala 2.10 is not yet supported). The project is built using Simple Build Tool (SBT), which is packaged with it. To build Spark and its example programs, run: - sbt/sbt package assembly + sbt/sbt assembly Spark also supports building using Maven. If you would like to build using Maven, see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html) @@ -52,19 +52,19 @@ For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: # Apache Hadoop 1.2.1 - $ SPARK_HADOOP_VERSION=1.2.1 sbt/sbt package assembly + $ SPARK_HADOOP_VERSION=1.2.1 sbt/sbt assembly # Cloudera CDH 4.2.0 with MapReduce v1 - $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt package assembly + $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions with YARN, also set `SPARK_WITH_YARN=true`: # Apache Hadoop 2.0.5-alpha - $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true sbt/sbt package assembly + $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true sbt/sbt assembly # Cloudera CDH 4.2.0 with MapReduce v2 - $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_WITH_YARN=true sbt/sbt package assembly + $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_WITH_YARN=true sbt/sbt assembly For convenience, these variables may also be set through the `conf/spark-env.sh` file described below. diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index a9f2cb8a7a..72d37fec0a 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -15,18 +15,18 @@ To enable support for HDFS and other Hadoop-supported storage systems, specify t For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: # Apache Hadoop 1.2.1 - $ mvn -Dhadoop.version=1.2.1 clean install + $ mvn -Dhadoop.version=1.2.1 clean package # Cloudera CDH 4.2.0 with MapReduce v1 - $ mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 clean install + $ mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 clean package For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions with YARN, enable the "hadoop2-yarn" profile: # Apache Hadoop 2.0.5-alpha - $ mvn -Phadoop2-yarn -Dhadoop.version=2.0.5-alpha clean install + $ mvn -Phadoop2-yarn -Dhadoop.version=2.0.5-alpha clean package # Cloudera CDH 4.2.0 with MapReduce v2 - $ mvn -Phadoop2-yarn -Dhadoop.version=2.0.0-cdh4.2.0 clean install + $ mvn -Phadoop2-yarn -Dhadoop.version=2.0.0-cdh4.2.0 clean package ## Spark Tests in Maven ## @@ -35,7 +35,7 @@ Tests are run by default via the scalatest-maven-plugin. With this you can do th Skip test execution (but not compilation): - $ mvn -Dhadoop.version=... -DskipTests clean install + $ mvn -Dhadoop.version=... -DskipTests clean package To run a specific test suite: @@ -72,8 +72,8 @@ This setup works fine in IntelliJ IDEA 11.1.4. After opening the project via the ## Building Spark Debian Packages ## -It includes support for building a Debian package containing a 'fat-jar' which includes the repl, the examples and bagel. This can be created by specifying the deb profile: +It includes support for building a Debian package containing a 'fat-jar' which includes the repl, the examples and bagel. This can be created by specifying the following profiles: - $ mvn -Pdeb clean install + $ mvn -Prepl-bin -Pdeb clean package The debian package can then be found under repl/target. We added the short commit hash to the file name so that we can distinguish individual packages build for SNAPSHOT versions. diff --git a/docs/index.md b/docs/index.md index e51a6998f6..ec9c7dd4f3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ of these methods on slave nodes on your cluster. Spark uses [Simple Build Tool](http://www.scala-sbt.org), which is bundled with it. To compile the code, go into the top-level Spark directory and run - sbt/sbt package + sbt/sbt assembly Spark also supports building using Maven. If you would like to build using Maven, see the [instructions for building Spark with Maven](building-with-maven.html). diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 794bff5647..15d3ebfcae 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -70,7 +70,7 @@ The script automatically adds the `pyspark` package to the `PYTHONPATH`. The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. To use `pyspark` interactively, first build Spark, then launch it directly from the command line without any options: {% highlight bash %} -$ sbt/sbt package +$ sbt/sbt assembly $ ./pyspark {% endhighlight %} diff --git a/docs/quick-start.md b/docs/quick-start.md index 335643536a..4e9deadbaa 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -12,7 +12,7 @@ See the [programming guide](scala-programming-guide.html) for a more complete re To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run: {% highlight bash %} -$ sbt/sbt package +$ sbt/sbt assembly {% endhighlight %} # Interactive Analysis with the Spark Shell -- cgit v1.2.3 From 2c5a4b89ee034b7933b258cfc37bc6d91a06c186 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 31 Aug 2013 18:08:05 -0700 Subject: Small fixes to README --- README.md | 42 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) (limited to 'README.md') diff --git a/README.md b/README.md index 89b5a0abfd..2ddfe862a2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ -# Spark +# Apache Spark -Lightning-Fast Cluster Computing - +Lightning-Fast Cluster Computing - ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the project webpage at . +guide, on the project webpage at . This README file only contains basic setup instructions. @@ -18,16 +18,14 @@ Spark and its example programs, run: sbt/sbt assembly -Spark also supports building using Maven. If you would like to build using Maven, -see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html) -in the Spark documentation.. +Once you've built Spark, the easiest way to start using it is the shell: -To run Spark, you will need to have Scala's bin directory in your `PATH`, or -you will need to set the `SCALA_HOME` environment variable to point to where -you've installed Scala. Scala must be accessible through one of these -methods on your cluster's worker nodes as well as its master. + ./spark-shell -To run one of the examples, use `./run-example `. For example: +Or, for the Python API, the Python shell (`./pyspark`). + +Spark also comes with several sample programs in the `examples` directory. +To run one of them, use `./run-example `. For example: ./run-example spark.examples.SparkLR local[2] @@ -35,7 +33,7 @@ will run the Logistic Regression example locally on 2 CPUs. Each of the example programs prints usage help if no params are given. -All of the Spark samples take a `` parameter that is the cluster URL +All of the Spark samples take a `` parameter that is the cluster URL to connect to. This can be a mesos:// or spark:// URL, or "local" to run locally with one thread, or "local[N]" to run locally with N threads. @@ -58,13 +56,13 @@ versions without YARN, use: $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions -with YARN, also set `SPARK_WITH_YARN=true`: +with YARN, also set `SPARK_YARN=true`: # Apache Hadoop 2.0.5-alpha - $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true sbt/sbt assembly + $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly # Cloudera CDH 4.2.0 with MapReduce v2 - $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_WITH_YARN=true sbt/sbt assembly + $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_YARN=true sbt/sbt assembly For convenience, these variables may also be set through the `conf/spark-env.sh` file described below. @@ -81,22 +79,14 @@ If your project is built with Maven, add this to your POM file's ` org.apache.hadoop hadoop-client - - [1.2.1] + 1.2.1 ## Configuration -Please refer to the "Configuration" guide in the online documentation for a -full overview on how to configure Spark. At the minimum, you will need to -create a `conf/spark-env.sh` script (copy `conf/spark-env.sh.template`) and -set the following two variables: - -- `SCALA_HOME`: Location where Scala is installed. - -- `MESOS_NATIVE_LIBRARY`: Your Mesos library (only needed if you want to run - on Mesos). For example, this might be `/usr/local/lib/libmesos.so` on Linux. +Please refer to the [Configuration guide](http://spark.incubator.apache.org/docs/latest/configuration.html) +in the online documentation for an overview on how to configure Spark. ## Contributing to Spark -- cgit v1.2.3 From 46eecd110a4017ea0c86cbb1010d0ccd6a5eb2ef Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 31 Aug 2013 19:27:07 -0700 Subject: Initial work to rename package to org.apache.spark --- README.md | 2 +- assembly/pom.xml | 16 +- assembly/src/main/assembly/assembly.xml | 10 +- bagel/pom.xml | 6 +- .../main/scala/org/apache/spark/bagel/Bagel.scala | 293 +++++ bagel/src/main/scala/spark/bagel/Bagel.scala | 294 ----- bagel/src/test/scala/bagel/BagelSuite.scala | 118 -- .../scala/org/apache/spark/bagel/BagelSuite.scala | 116 ++ bin/start-master.sh | 2 +- bin/start-slave.sh | 2 +- bin/stop-master.sh | 2 +- bin/stop-slaves.sh | 4 +- core/pom.xml | 4 +- .../org/apache/spark/network/netty/FileClient.java | 89 ++ .../netty/FileClientChannelInitializer.java | 41 + .../spark/network/netty/FileClientHandler.java | 60 + .../org/apache/spark/network/netty/FileServer.java | 103 ++ .../netty/FileServerChannelInitializer.java | 42 + .../spark/network/netty/FileServerHandler.java | 82 ++ .../apache/spark/network/netty/PathResolver.java | 29 + .../main/java/spark/network/netty/FileClient.java | 89 -- .../netty/FileClientChannelInitializer.java | 41 - .../spark/network/netty/FileClientHandler.java | 60 - .../main/java/spark/network/netty/FileServer.java | 103 -- .../netty/FileServerChannelInitializer.java | 42 - .../spark/network/netty/FileServerHandler.java | 82 -- .../java/spark/network/netty/PathResolver.java | 29 - .../org/apache/spark/ui/static/bootstrap.min.css | 874 +++++++++++++ .../org/apache/spark/ui/static/sorttable.js | 495 ++++++++ .../spark/ui/static/spark-logo-77x50px-hd.png | Bin 0 -> 3536 bytes .../org/apache/spark/ui/static/spark_logo.png | Bin 0 -> 14233 bytes .../resources/org/apache/spark/ui/static/webui.css | 63 + .../resources/spark/ui/static/bootstrap.min.css | 874 ------------- .../main/resources/spark/ui/static/sorttable.js | 495 -------- .../spark/ui/static/spark-logo-77x50px-hd.png | Bin 3536 -> 0 bytes .../main/resources/spark/ui/static/spark_logo.png | Bin 14233 -> 0 bytes core/src/main/resources/spark/ui/static/webui.css | 63 - .../main/scala/org/apache/spark/Accumulators.scala | 256 ++++ .../main/scala/org/apache/spark/Aggregator.scala | 61 + .../apache/spark/BlockStoreShuffleFetcher.scala | 89 ++ .../main/scala/org/apache/spark/CacheManager.scala | 82 ++ .../scala/org/apache/spark/ClosureCleaner.scala | 231 ++++ .../main/scala/org/apache/spark/Dependency.scala | 81 ++ .../org/apache/spark/DoubleRDDFunctions.scala | 78 ++ .../org/apache/spark/FetchFailedException.scala | 44 + .../scala/org/apache/spark/HttpFileServer.scala | 62 + .../main/scala/org/apache/spark/HttpServer.scala | 88 ++ .../scala/org/apache/spark/JavaSerializer.scala | 83 ++ .../scala/org/apache/spark/KryoSerializer.scala | 156 +++ core/src/main/scala/org/apache/spark/Logging.scala | 95 ++ .../scala/org/apache/spark/MapOutputTracker.scala | 338 +++++ .../scala/org/apache/spark/PairRDDFunctions.scala | 703 +++++++++++ .../main/scala/org/apache/spark/Partition.scala | 31 + .../main/scala/org/apache/spark/Partitioner.scala | 135 ++ core/src/main/scala/org/apache/spark/RDD.scala | 957 ++++++++++++++ .../scala/org/apache/spark/RDDCheckpointData.scala | 130 ++ .../apache/spark/SequenceFileRDDFunctions.scala | 107 ++ .../org/apache/spark/SerializableWritable.scala | 42 + .../scala/org/apache/spark/ShuffleFetcher.scala | 35 + .../scala/org/apache/spark/SizeEstimator.scala | 283 +++++ .../main/scala/org/apache/spark/SparkContext.scala | 995 +++++++++++++++ .../src/main/scala/org/apache/spark/SparkEnv.scala | 240 ++++ .../scala/org/apache/spark/SparkException.scala | 24 + .../main/scala/org/apache/spark/SparkFiles.java | 42 + .../scala/org/apache/spark/SparkHadoopWriter.scala | 201 +++ .../main/scala/org/apache/spark/TaskContext.scala | 41 + .../scala/org/apache/spark/TaskEndReason.scala | 51 + .../main/scala/org/apache/spark/TaskState.scala | 51 + core/src/main/scala/org/apache/spark/Utils.scala | 780 ++++++++++++ .../org/apache/spark/api/java/JavaDoubleRDD.scala | 167 +++ .../org/apache/spark/api/java/JavaPairRDD.scala | 601 +++++++++ .../scala/org/apache/spark/api/java/JavaRDD.scala | 114 ++ .../org/apache/spark/api/java/JavaRDDLike.scala | 426 +++++++ .../apache/spark/api/java/JavaSparkContext.scala | 418 +++++++ .../java/JavaSparkContextVarargsWorkaround.java | 64 + .../org/apache/spark/api/java/JavaUtils.scala | 28 + .../org/apache/spark/api/java/StorageLevels.java | 48 + .../api/java/function/DoubleFlatMapFunction.java | 37 + .../spark/api/java/function/DoubleFunction.java | 34 + .../spark/api/java/function/FlatMapFunction.scala | 28 + .../spark/api/java/function/FlatMapFunction2.scala | 28 + .../apache/spark/api/java/function/Function.java | 39 + .../apache/spark/api/java/function/Function2.java | 38 + .../api/java/function/PairFlatMapFunction.java | 46 + .../spark/api/java/function/PairFunction.java | 45 + .../spark/api/java/function/VoidFunction.scala | 33 + .../spark/api/java/function/WrappedFunction1.scala | 32 + .../spark/api/java/function/WrappedFunction2.scala | 32 + .../spark/api/python/PythonPartitioner.scala | 50 + .../org/apache/spark/api/python/PythonRDD.scala | 344 ++++++ .../spark/api/python/PythonWorkerFactory.scala | 132 ++ .../spark/broadcast/BitTorrentBroadcast.scala | 1057 ++++++++++++++++ .../org/apache/spark/broadcast/Broadcast.scala | 70 ++ .../apache/spark/broadcast/BroadcastFactory.scala | 30 + .../org/apache/spark/broadcast/HttpBroadcast.scala | 171 +++ .../org/apache/spark/broadcast/MultiTracker.scala | 409 ++++++ .../org/apache/spark/broadcast/SourceInfo.scala | 54 + .../org/apache/spark/broadcast/TreeBroadcast.scala | 602 +++++++++ .../spark/deploy/ApplicationDescription.scala | 32 + .../scala/org/apache/spark/deploy/Command.scala | 26 + .../org/apache/spark/deploy/DeployMessage.scala | 130 ++ .../org/apache/spark/deploy/ExecutorState.scala | 28 + .../org/apache/spark/deploy/JsonProtocol.scala | 86 ++ .../apache/spark/deploy/LocalSparkCluster.scala | 69 ++ .../org/apache/spark/deploy/SparkHadoopUtil.scala | 36 + .../main/scala/org/apache/spark/deploy/WebUI.scala | 47 + .../org/apache/spark/deploy/client/Client.scala | 145 +++ .../spark/deploy/client/ClientListener.scala | 35 + .../apache/spark/deploy/client/TestClient.scala | 51 + .../apache/spark/deploy/client/TestExecutor.scala | 27 + .../spark/deploy/master/ApplicationInfo.scala | 85 ++ .../spark/deploy/master/ApplicationSource.scala | 24 + .../spark/deploy/master/ApplicationState.scala | 28 + .../apache/spark/deploy/master/ExecutorInfo.scala | 32 + .../org/apache/spark/deploy/master/Master.scala | 386 ++++++ .../spark/deploy/master/MasterArguments.scala | 89 ++ .../apache/spark/deploy/master/MasterSource.scala | 25 + .../apache/spark/deploy/master/WorkerInfo.scala | 77 ++ .../apache/spark/deploy/master/WorkerState.scala | 24 + .../spark/deploy/master/ui/ApplicationPage.scala | 118 ++ .../apache/spark/deploy/master/ui/IndexPage.scala | 141 +++ .../spark/deploy/master/ui/MasterWebUI.scala | 80 ++ .../spark/deploy/worker/ExecutorRunner.scala | 199 +++ .../org/apache/spark/deploy/worker/Worker.scala | 213 ++++ .../spark/deploy/worker/WorkerArguments.scala | 153 +++ .../apache/spark/deploy/worker/WorkerSource.scala | 34 + .../apache/spark/deploy/worker/ui/IndexPage.scala | 115 ++ .../spark/deploy/worker/ui/WorkerWebUI.scala | 190 +++ .../scala/org/apache/spark/executor/Executor.scala | 269 ++++ .../apache/spark/executor/ExecutorBackend.scala | 28 + .../apache/spark/executor/ExecutorExitCode.scala | 60 + .../org/apache/spark/executor/ExecutorSource.scala | 55 + .../spark/executor/ExecutorURLClassLoader.scala | 31 + .../spark/executor/MesosExecutorBackend.scala | 95 ++ .../spark/executor/StandaloneExecutorBackend.scala | 107 ++ .../org/apache/spark/executor/TaskMetrics.scala | 105 ++ .../org/apache/spark/io/CompressionCodec.scala | 82 ++ .../org/apache/spark/metrics/MetricsConfig.scala | 100 ++ .../org/apache/spark/metrics/MetricsSystem.scala | 163 +++ .../apache/spark/metrics/sink/ConsoleSink.scala | 59 + .../org/apache/spark/metrics/sink/CsvSink.scala | 68 + .../org/apache/spark/metrics/sink/JmxSink.scala | 35 + .../apache/spark/metrics/sink/MetricsServlet.scala | 55 + .../scala/org/apache/spark/metrics/sink/Sink.scala | 23 + .../apache/spark/metrics/source/JvmSource.scala | 32 + .../org/apache/spark/metrics/source/Source.scala | 25 + .../org/apache/spark/network/BufferMessage.scala | 111 ++ .../org/apache/spark/network/Connection.scala | 586 +++++++++ .../apache/spark/network/ConnectionManager.scala | 720 +++++++++++ .../apache/spark/network/ConnectionManagerId.scala | 38 + .../spark/network/ConnectionManagerTest.scala | 102 ++ .../scala/org/apache/spark/network/Message.scala | 93 ++ .../org/apache/spark/network/MessageChunk.scala | 42 + .../apache/spark/network/MessageChunkHeader.scala | 75 ++ .../org/apache/spark/network/ReceiverTest.scala | 37 + .../org/apache/spark/network/SenderTest.scala | 70 ++ .../apache/spark/network/netty/FileHeader.scala | 74 ++ .../apache/spark/network/netty/ShuffleCopier.scala | 118 ++ .../apache/spark/network/netty/ShuffleSender.scala | 70 ++ core/src/main/scala/org/apache/spark/package.scala | 32 + .../spark/partial/ApproximateActionListener.scala | 87 ++ .../spark/partial/ApproximateEvaluator.scala | 27 + .../org/apache/spark/partial/BoundedDouble.scala | 25 + .../org/apache/spark/partial/CountEvaluator.scala | 55 + .../spark/partial/GroupedCountEvaluator.scala | 79 ++ .../spark/partial/GroupedMeanEvaluator.scala | 82 ++ .../apache/spark/partial/GroupedSumEvaluator.scala | 89 ++ .../org/apache/spark/partial/MeanEvaluator.scala | 58 + .../org/apache/spark/partial/PartialResult.scala | 137 ++ .../org/apache/spark/partial/StudentTCacher.scala | 43 + .../org/apache/spark/partial/SumEvaluator.scala | 68 + .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 51 + .../scala/org/apache/spark/rdd/CartesianRDD.scala | 90 ++ .../scala/org/apache/spark/rdd/CheckpointRDD.scala | 155 +++ .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 144 +++ .../scala/org/apache/spark/rdd/CoalescedRDD.scala | 342 +++++ .../main/scala/org/apache/spark/rdd/EmptyRDD.scala | 33 + .../scala/org/apache/spark/rdd/FilteredRDD.scala | 33 + .../scala/org/apache/spark/rdd/FlatMappedRDD.scala | 33 + .../org/apache/spark/rdd/FlatMappedValuesRDD.scala | 36 + .../scala/org/apache/spark/rdd/GlommedRDD.scala | 29 + .../scala/org/apache/spark/rdd/HadoopRDD.scala | 137 ++ .../main/scala/org/apache/spark/rdd/JdbcRDD.scala | 120 ++ .../org/apache/spark/rdd/MapPartitionsRDD.scala | 37 + .../spark/rdd/MapPartitionsWithIndexRDD.scala | 41 + .../scala/org/apache/spark/rdd/MappedRDD.scala | 30 + .../org/apache/spark/rdd/MappedValuesRDD.scala | 34 + .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 126 ++ .../org/apache/spark/rdd/OrderedRDDFunctions.scala | 51 + .../apache/spark/rdd/ParallelCollectionRDD.scala | 151 +++ .../org/apache/spark/rdd/PartitionPruningRDD.scala | 72 ++ .../main/scala/org/apache/spark/rdd/PipedRDD.scala | 125 ++ .../scala/org/apache/spark/rdd/SampledRDD.scala | 66 + .../scala/org/apache/spark/rdd/ShuffledRDD.scala | 67 + .../scala/org/apache/spark/rdd/SubtractedRDD.scala | 129 ++ .../main/scala/org/apache/spark/rdd/UnionRDD.scala | 73 ++ .../org/apache/spark/rdd/ZippedPartitionsRDD.scala | 143 +++ .../scala/org/apache/spark/rdd/ZippedRDD.scala | 85 ++ .../org/apache/spark/scheduler/ActiveJob.scala | 39 + .../org/apache/spark/scheduler/DAGScheduler.scala | 849 +++++++++++++ .../apache/spark/scheduler/DAGSchedulerEvent.scala | 63 + .../spark/scheduler/DAGSchedulerSource.scala | 30 + .../apache/spark/scheduler/InputFormatInfo.scala | 178 +++ .../org/apache/spark/scheduler/JobListener.scala | 28 + .../org/apache/spark/scheduler/JobLogger.scala | 292 +++++ .../org/apache/spark/scheduler/JobResult.scala | 26 + .../org/apache/spark/scheduler/JobWaiter.scala | 66 + .../org/apache/spark/scheduler/MapStatus.scala | 44 + .../org/apache/spark/scheduler/ResultTask.scala | 134 ++ .../apache/spark/scheduler/ShuffleMapTask.scala | 189 +++ .../org/apache/spark/scheduler/SparkListener.scala | 204 +++ .../apache/spark/scheduler/SparkListenerBus.scala | 74 ++ .../org/apache/spark/scheduler/SplitInfo.scala | 78 ++ .../scala/org/apache/spark/scheduler/Stage.scala | 112 ++ .../org/apache/spark/scheduler/StageInfo.scala | 29 + .../scala/org/apache/spark/scheduler/Task.scala | 115 ++ .../org/apache/spark/scheduler/TaskLocation.scala | 34 + .../org/apache/spark/scheduler/TaskResult.scala | 72 ++ .../org/apache/spark/scheduler/TaskScheduler.scala | 52 + .../spark/scheduler/TaskSchedulerListener.scala | 45 + .../scala/org/apache/spark/scheduler/TaskSet.scala | 35 + .../spark/scheduler/cluster/ClusterScheduler.scala | 440 +++++++ .../scheduler/cluster/ClusterTaskSetManager.scala | 712 +++++++++++ .../scheduler/cluster/ExecutorLossReason.scala | 38 + .../org/apache/spark/scheduler/cluster/Pool.scala | 121 ++ .../spark/scheduler/cluster/Schedulable.scala | 48 + .../scheduler/cluster/SchedulableBuilder.scala | 137 ++ .../spark/scheduler/cluster/SchedulerBackend.scala | 37 + .../scheduler/cluster/SchedulingAlgorithm.scala | 81 ++ .../spark/scheduler/cluster/SchedulingMode.scala | 29 + .../cluster/SparkDeploySchedulerBackend.scala | 91 ++ .../cluster/StandaloneClusterMessage.scala | 63 + .../cluster/StandaloneSchedulerBackend.scala | 198 +++ .../spark/scheduler/cluster/TaskDescription.scala | 37 + .../apache/spark/scheduler/cluster/TaskInfo.scala | 72 ++ .../spark/scheduler/cluster/TaskLocality.scala | 32 + .../spark/scheduler/cluster/TaskSetManager.scala | 51 + .../spark/scheduler/cluster/WorkerOffer.scala | 24 + .../spark/scheduler/local/LocalScheduler.scala | 272 ++++ .../scheduler/local/LocalTaskSetManager.scala | 194 +++ .../mesos/CoarseMesosSchedulerBackend.scala | 286 +++++ .../scheduler/mesos/MesosSchedulerBackend.scala | 342 +++++ .../org/apache/spark/serializer/Serializer.scala | 112 ++ .../spark/serializer/SerializerManager.scala | 62 + .../org/apache/spark/storage/BlockException.scala | 22 + .../apache/spark/storage/BlockFetchTracker.scala | 27 + .../spark/storage/BlockFetcherIterator.scala | 348 ++++++ .../org/apache/spark/storage/BlockManager.scala | 1046 ++++++++++++++++ .../org/apache/spark/storage/BlockManagerId.scala | 118 ++ .../apache/spark/storage/BlockManagerMaster.scala | 178 +++ .../spark/storage/BlockManagerMasterActor.scala | 404 ++++++ .../spark/storage/BlockManagerMessages.scala | 110 ++ .../spark/storage/BlockManagerSlaveActor.scala | 39 + .../apache/spark/storage/BlockManagerSource.scala | 48 + .../apache/spark/storage/BlockManagerWorker.scala | 139 +++ .../org/apache/spark/storage/BlockMessage.scala | 223 ++++ .../apache/spark/storage/BlockMessageArray.scala | 159 +++ .../apache/spark/storage/BlockObjectWriter.scala | 65 + .../org/apache/spark/storage/BlockStore.scala | 61 + .../scala/org/apache/spark/storage/DiskStore.scala | 329 +++++ .../org/apache/spark/storage/MemoryStore.scala | 257 ++++ .../scala/org/apache/spark/storage/PutResult.scala | 26 + .../apache/spark/storage/ShuffleBlockManager.scala | 67 + .../org/apache/spark/storage/StorageLevel.scala | 146 +++ .../org/apache/spark/storage/StorageUtils.scala | 115 ++ .../org/apache/spark/storage/ThreadingTest.scala | 113 ++ .../scala/org/apache/spark/ui/JettyUtils.scala | 132 ++ core/src/main/scala/org/apache/spark/ui/Page.scala | 22 + .../main/scala/org/apache/spark/ui/SparkUI.scala | 87 ++ .../main/scala/org/apache/spark/ui/UIUtils.scala | 131 ++ .../org/apache/spark/ui/UIWorkloadGenerator.scala | 105 ++ .../org/apache/spark/ui/env/EnvironmentUI.scala | 91 ++ .../org/apache/spark/ui/exec/ExecutorsUI.scala | 136 ++ .../scala/org/apache/spark/ui/jobs/IndexPage.scala | 90 ++ .../apache/spark/ui/jobs/JobProgressListener.scala | 156 +++ .../org/apache/spark/ui/jobs/JobProgressUI.scala | 60 + .../scala/org/apache/spark/ui/jobs/PoolPage.scala | 32 + .../scala/org/apache/spark/ui/jobs/PoolTable.scala | 55 + .../scala/org/apache/spark/ui/jobs/StagePage.scala | 183 +++ .../org/apache/spark/ui/jobs/StageTable.scala | 107 ++ .../apache/spark/ui/storage/BlockManagerUI.scala | 41 + .../org/apache/spark/ui/storage/IndexPage.scala | 65 + .../org/apache/spark/ui/storage/RDDPage.scala | 132 ++ .../scala/org/apache/spark/util/AkkaUtils.scala | 72 ++ .../apache/spark/util/BoundedPriorityQueue.scala | 62 + .../apache/spark/util/ByteBufferInputStream.scala | 80 ++ .../main/scala/org/apache/spark/util/Clock.scala | 29 + .../org/apache/spark/util/CompletionIterator.scala | 42 + .../scala/org/apache/spark/util/Distribution.scala | 82 ++ .../scala/org/apache/spark/util/IdGenerator.scala | 31 + .../scala/org/apache/spark/util/IntParam.scala | 31 + .../scala/org/apache/spark/util/MemoryParam.scala | 34 + .../org/apache/spark/util/MetadataCleaner.scala | 61 + .../scala/org/apache/spark/util/MutablePair.scala | 36 + .../scala/org/apache/spark/util/NextIterator.scala | 88 ++ .../spark/util/RateLimitedOutputStream.scala | 79 ++ .../org/apache/spark/util/SerializableBuffer.scala | 54 + .../scala/org/apache/spark/util/StatCounter.scala | 131 ++ .../org/apache/spark/util/TimeStampedHashMap.scala | 122 ++ .../org/apache/spark/util/TimeStampedHashSet.scala | 86 ++ .../main/scala/org/apache/spark/util/Vector.scala | 139 +++ core/src/main/scala/spark/Accumulators.scala | 256 ---- core/src/main/scala/spark/Aggregator.scala | 61 - .../scala/spark/BlockStoreShuffleFetcher.scala | 89 -- core/src/main/scala/spark/CacheManager.scala | 82 -- core/src/main/scala/spark/ClosureCleaner.scala | 231 ---- core/src/main/scala/spark/Dependency.scala | 81 -- core/src/main/scala/spark/DoubleRDDFunctions.scala | 78 -- .../main/scala/spark/FetchFailedException.scala | 44 - core/src/main/scala/spark/HttpFileServer.scala | 62 - core/src/main/scala/spark/HttpServer.scala | 88 -- core/src/main/scala/spark/JavaSerializer.scala | 83 -- core/src/main/scala/spark/KryoSerializer.scala | 156 --- core/src/main/scala/spark/Logging.scala | 95 -- core/src/main/scala/spark/MapOutputTracker.scala | 338 ----- core/src/main/scala/spark/PairRDDFunctions.scala | 703 ----------- core/src/main/scala/spark/Partition.scala | 31 - core/src/main/scala/spark/Partitioner.scala | 135 -- core/src/main/scala/spark/RDD.scala | 957 -------------- core/src/main/scala/spark/RDDCheckpointData.scala | 130 -- .../scala/spark/SequenceFileRDDFunctions.scala | 107 -- .../main/scala/spark/SerializableWritable.scala | 42 - core/src/main/scala/spark/ShuffleFetcher.scala | 35 - core/src/main/scala/spark/SizeEstimator.scala | 283 ----- core/src/main/scala/spark/SparkContext.scala | 995 --------------- core/src/main/scala/spark/SparkEnv.scala | 241 ---- core/src/main/scala/spark/SparkException.scala | 24 - core/src/main/scala/spark/SparkFiles.java | 42 - core/src/main/scala/spark/SparkHadoopWriter.scala | 201 --- core/src/main/scala/spark/TaskContext.scala | 41 - core/src/main/scala/spark/TaskEndReason.scala | 51 - core/src/main/scala/spark/TaskState.scala | 51 - core/src/main/scala/spark/Utils.scala | 780 ------------ .../main/scala/spark/api/java/JavaDoubleRDD.scala | 167 --- .../main/scala/spark/api/java/JavaPairRDD.scala | 601 --------- core/src/main/scala/spark/api/java/JavaRDD.scala | 114 -- .../main/scala/spark/api/java/JavaRDDLike.scala | 426 ------- .../scala/spark/api/java/JavaSparkContext.scala | 418 ------- .../java/JavaSparkContextVarargsWorkaround.java | 64 - core/src/main/scala/spark/api/java/JavaUtils.scala | 28 - .../main/scala/spark/api/java/StorageLevels.java | 48 - .../api/java/function/DoubleFlatMapFunction.java | 37 - .../spark/api/java/function/DoubleFunction.java | 34 - .../spark/api/java/function/FlatMapFunction.scala | 28 - .../spark/api/java/function/FlatMapFunction2.scala | 28 - .../scala/spark/api/java/function/Function.java | 39 - .../scala/spark/api/java/function/Function2.java | 38 - .../api/java/function/PairFlatMapFunction.java | 46 - .../spark/api/java/function/PairFunction.java | 45 - .../spark/api/java/function/VoidFunction.scala | 33 - .../spark/api/java/function/WrappedFunction1.scala | 32 - .../spark/api/java/function/WrappedFunction2.scala | 32 - .../scala/spark/api/python/PythonPartitioner.scala | 50 - .../main/scala/spark/api/python/PythonRDD.scala | 344 ------ .../spark/api/python/PythonWorkerFactory.scala | 132 -- .../spark/broadcast/BitTorrentBroadcast.scala | 1057 ---------------- .../src/main/scala/spark/broadcast/Broadcast.scala | 70 -- .../scala/spark/broadcast/BroadcastFactory.scala | 30 - .../main/scala/spark/broadcast/HttpBroadcast.scala | 171 --- .../main/scala/spark/broadcast/MultiTracker.scala | 409 ------ .../main/scala/spark/broadcast/SourceInfo.scala | 54 - .../main/scala/spark/broadcast/TreeBroadcast.scala | 602 --------- .../spark/deploy/ApplicationDescription.scala | 32 - core/src/main/scala/spark/deploy/Command.scala | 26 - .../main/scala/spark/deploy/DeployMessage.scala | 130 -- .../main/scala/spark/deploy/ExecutorState.scala | 28 - .../src/main/scala/spark/deploy/JsonProtocol.scala | 86 -- .../scala/spark/deploy/LocalSparkCluster.scala | 69 -- .../main/scala/spark/deploy/SparkHadoopUtil.scala | 36 - core/src/main/scala/spark/deploy/WebUI.scala | 47 - .../main/scala/spark/deploy/client/Client.scala | 145 --- .../scala/spark/deploy/client/ClientListener.scala | 35 - .../scala/spark/deploy/client/TestClient.scala | 51 - .../scala/spark/deploy/client/TestExecutor.scala | 27 - .../spark/deploy/master/ApplicationInfo.scala | 85 -- .../spark/deploy/master/ApplicationSource.scala | 24 - .../spark/deploy/master/ApplicationState.scala | 28 - .../scala/spark/deploy/master/ExecutorInfo.scala | 32 - .../main/scala/spark/deploy/master/Master.scala | 386 ------ .../spark/deploy/master/MasterArguments.scala | 89 -- .../scala/spark/deploy/master/MasterSource.scala | 25 - .../scala/spark/deploy/master/WorkerInfo.scala | 77 -- .../scala/spark/deploy/master/WorkerState.scala | 24 - .../spark/deploy/master/ui/ApplicationPage.scala | 118 -- .../scala/spark/deploy/master/ui/IndexPage.scala | 141 --- .../scala/spark/deploy/master/ui/MasterWebUI.scala | 80 -- .../scala/spark/deploy/worker/ExecutorRunner.scala | 199 --- .../main/scala/spark/deploy/worker/Worker.scala | 213 ---- .../spark/deploy/worker/WorkerArguments.scala | 153 --- .../scala/spark/deploy/worker/WorkerSource.scala | 34 - .../scala/spark/deploy/worker/ui/IndexPage.scala | 115 -- .../scala/spark/deploy/worker/ui/WorkerWebUI.scala | 190 --- core/src/main/scala/spark/executor/Executor.scala | 269 ---- .../scala/spark/executor/ExecutorBackend.scala | 28 - .../scala/spark/executor/ExecutorExitCode.scala | 60 - .../main/scala/spark/executor/ExecutorSource.scala | 55 - .../spark/executor/ExecutorURLClassLoader.scala | 31 - .../spark/executor/MesosExecutorBackend.scala | 95 -- .../spark/executor/StandaloneExecutorBackend.scala | 107 -- .../main/scala/spark/executor/TaskMetrics.scala | 105 -- .../src/main/scala/spark/io/CompressionCodec.scala | 82 -- .../main/scala/spark/metrics/MetricsConfig.scala | 100 -- .../main/scala/spark/metrics/MetricsSystem.scala | 163 --- .../scala/spark/metrics/sink/ConsoleSink.scala | 59 - .../main/scala/spark/metrics/sink/CsvSink.scala | 68 - .../main/scala/spark/metrics/sink/JmxSink.scala | 35 - .../scala/spark/metrics/sink/MetricsServlet.scala | 55 - core/src/main/scala/spark/metrics/sink/Sink.scala | 23 - .../scala/spark/metrics/source/JvmSource.scala | 32 - .../main/scala/spark/metrics/source/Source.scala | 25 - .../main/scala/spark/network/BufferMessage.scala | 111 -- core/src/main/scala/spark/network/Connection.scala | 586 --------- .../scala/spark/network/ConnectionManager.scala | 720 ----------- .../scala/spark/network/ConnectionManagerId.scala | 38 - .../spark/network/ConnectionManagerTest.scala | 102 -- core/src/main/scala/spark/network/Message.scala | 93 -- .../main/scala/spark/network/MessageChunk.scala | 42 - .../scala/spark/network/MessageChunkHeader.scala | 75 -- .../main/scala/spark/network/ReceiverTest.scala | 37 - core/src/main/scala/spark/network/SenderTest.scala | 70 -- .../scala/spark/network/netty/FileHeader.scala | 74 -- .../scala/spark/network/netty/ShuffleCopier.scala | 118 -- .../scala/spark/network/netty/ShuffleSender.scala | 70 -- core/src/main/scala/spark/package.scala | 32 - .../spark/partial/ApproximateActionListener.scala | 87 -- .../scala/spark/partial/ApproximateEvaluator.scala | 27 - .../main/scala/spark/partial/BoundedDouble.scala | 25 - .../main/scala/spark/partial/CountEvaluator.scala | 55 - .../spark/partial/GroupedCountEvaluator.scala | 79 -- .../scala/spark/partial/GroupedMeanEvaluator.scala | 82 -- .../scala/spark/partial/GroupedSumEvaluator.scala | 89 -- .../main/scala/spark/partial/MeanEvaluator.scala | 58 - .../main/scala/spark/partial/PartialResult.scala | 137 -- .../main/scala/spark/partial/StudentTCacher.scala | 43 - .../main/scala/spark/partial/SumEvaluator.scala | 68 - core/src/main/scala/spark/rdd/BlockRDD.scala | 51 - core/src/main/scala/spark/rdd/CartesianRDD.scala | 90 -- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 155 --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 144 --- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 342 ----- core/src/main/scala/spark/rdd/EmptyRDD.scala | 33 - core/src/main/scala/spark/rdd/FilteredRDD.scala | 33 - core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 33 - .../main/scala/spark/rdd/FlatMappedValuesRDD.scala | 36 - core/src/main/scala/spark/rdd/GlommedRDD.scala | 29 - core/src/main/scala/spark/rdd/HadoopRDD.scala | 137 -- core/src/main/scala/spark/rdd/JdbcRDD.scala | 120 -- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 37 - .../spark/rdd/MapPartitionsWithIndexRDD.scala | 41 - core/src/main/scala/spark/rdd/MappedRDD.scala | 30 - .../src/main/scala/spark/rdd/MappedValuesRDD.scala | 34 - core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 126 -- .../main/scala/spark/rdd/OrderedRDDFunctions.scala | 51 - .../scala/spark/rdd/ParallelCollectionRDD.scala | 151 --- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 72 -- core/src/main/scala/spark/rdd/PipedRDD.scala | 125 -- core/src/main/scala/spark/rdd/SampledRDD.scala | 66 - core/src/main/scala/spark/rdd/ShuffledRDD.scala | 67 - core/src/main/scala/spark/rdd/SubtractedRDD.scala | 129 -- core/src/main/scala/spark/rdd/UnionRDD.scala | 73 -- .../main/scala/spark/rdd/ZippedPartitionsRDD.scala | 143 --- core/src/main/scala/spark/rdd/ZippedRDD.scala | 85 -- .../src/main/scala/spark/scheduler/ActiveJob.scala | 39 - .../main/scala/spark/scheduler/DAGScheduler.scala | 849 ------------- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 63 - .../scala/spark/scheduler/DAGSchedulerSource.scala | 30 - .../scala/spark/scheduler/InputFormatInfo.scala | 178 --- .../main/scala/spark/scheduler/JobListener.scala | 28 - .../src/main/scala/spark/scheduler/JobLogger.scala | 292 ----- .../src/main/scala/spark/scheduler/JobResult.scala | 26 - .../src/main/scala/spark/scheduler/JobWaiter.scala | 66 - .../src/main/scala/spark/scheduler/MapStatus.scala | 44 - .../main/scala/spark/scheduler/ResultTask.scala | 134 -- .../scala/spark/scheduler/ShuffleMapTask.scala | 189 --- .../main/scala/spark/scheduler/SparkListener.scala | 204 --- .../scala/spark/scheduler/SparkListenerBus.scala | 74 -- .../src/main/scala/spark/scheduler/SplitInfo.scala | 78 -- core/src/main/scala/spark/scheduler/Stage.scala | 112 -- .../src/main/scala/spark/scheduler/StageInfo.scala | 29 - core/src/main/scala/spark/scheduler/Task.scala | 115 -- .../main/scala/spark/scheduler/TaskLocation.scala | 34 - .../main/scala/spark/scheduler/TaskResult.scala | 72 -- .../main/scala/spark/scheduler/TaskScheduler.scala | 52 - .../spark/scheduler/TaskSchedulerListener.scala | 45 - core/src/main/scala/spark/scheduler/TaskSet.scala | 35 - .../spark/scheduler/cluster/ClusterScheduler.scala | 440 ------- .../scheduler/cluster/ClusterTaskSetManager.scala | 712 ----------- .../scheduler/cluster/ExecutorLossReason.scala | 38 - .../main/scala/spark/scheduler/cluster/Pool.scala | 121 -- .../spark/scheduler/cluster/Schedulable.scala | 48 - .../scheduler/cluster/SchedulableBuilder.scala | 137 -- .../spark/scheduler/cluster/SchedulerBackend.scala | 37 - .../scheduler/cluster/SchedulingAlgorithm.scala | 81 -- .../spark/scheduler/cluster/SchedulingMode.scala | 29 - .../cluster/SparkDeploySchedulerBackend.scala | 90 -- .../cluster/StandaloneClusterMessage.scala | 63 - .../cluster/StandaloneSchedulerBackend.scala | 198 --- .../spark/scheduler/cluster/TaskDescription.scala | 37 - .../scala/spark/scheduler/cluster/TaskInfo.scala | 72 -- .../spark/scheduler/cluster/TaskLocality.scala | 32 - .../spark/scheduler/cluster/TaskSetManager.scala | 51 - .../spark/scheduler/cluster/WorkerOffer.scala | 24 - .../spark/scheduler/local/LocalScheduler.scala | 272 ---- .../scheduler/local/LocalTaskSetManager.scala | 194 --- .../mesos/CoarseMesosSchedulerBackend.scala | 284 ----- .../scheduler/mesos/MesosSchedulerBackend.scala | 342 ----- .../main/scala/spark/serializer/Serializer.scala | 112 -- .../scala/spark/serializer/SerializerManager.scala | 62 - .../main/scala/spark/storage/BlockException.scala | 22 - .../scala/spark/storage/BlockFetchTracker.scala | 27 - .../scala/spark/storage/BlockFetcherIterator.scala | 348 ------ .../main/scala/spark/storage/BlockManager.scala | 1046 ---------------- .../main/scala/spark/storage/BlockManagerId.scala | 118 -- .../scala/spark/storage/BlockManagerMaster.scala | 178 --- .../spark/storage/BlockManagerMasterActor.scala | 404 ------ .../scala/spark/storage/BlockManagerMessages.scala | 110 -- .../spark/storage/BlockManagerSlaveActor.scala | 39 - .../scala/spark/storage/BlockManagerSource.scala | 48 - .../scala/spark/storage/BlockManagerWorker.scala | 139 --- .../main/scala/spark/storage/BlockMessage.scala | 223 ---- .../scala/spark/storage/BlockMessageArray.scala | 159 --- .../scala/spark/storage/BlockObjectWriter.scala | 65 - core/src/main/scala/spark/storage/BlockStore.scala | 61 - core/src/main/scala/spark/storage/DiskStore.scala | 329 ----- .../src/main/scala/spark/storage/MemoryStore.scala | 257 ---- core/src/main/scala/spark/storage/PutResult.scala | 26 - .../scala/spark/storage/ShuffleBlockManager.scala | 67 - .../main/scala/spark/storage/StorageLevel.scala | 146 --- .../main/scala/spark/storage/StorageUtils.scala | 115 -- .../main/scala/spark/storage/ThreadingTest.scala | 113 -- core/src/main/scala/spark/ui/JettyUtils.scala | 132 -- core/src/main/scala/spark/ui/Page.scala | 22 - core/src/main/scala/spark/ui/SparkUI.scala | 87 -- core/src/main/scala/spark/ui/UIUtils.scala | 131 -- .../main/scala/spark/ui/UIWorkloadGenerator.scala | 105 -- .../main/scala/spark/ui/env/EnvironmentUI.scala | 91 -- .../src/main/scala/spark/ui/exec/ExecutorsUI.scala | 136 -- core/src/main/scala/spark/ui/jobs/IndexPage.scala | 90 -- .../scala/spark/ui/jobs/JobProgressListener.scala | 156 --- .../main/scala/spark/ui/jobs/JobProgressUI.scala | 60 - core/src/main/scala/spark/ui/jobs/PoolPage.scala | 32 - core/src/main/scala/spark/ui/jobs/PoolTable.scala | 55 - core/src/main/scala/spark/ui/jobs/StagePage.scala | 183 --- core/src/main/scala/spark/ui/jobs/StageTable.scala | 107 -- .../scala/spark/ui/storage/BlockManagerUI.scala | 41 - .../main/scala/spark/ui/storage/IndexPage.scala | 65 - core/src/main/scala/spark/ui/storage/RDDPage.scala | 132 -- core/src/main/scala/spark/util/AkkaUtils.scala | 72 -- .../scala/spark/util/BoundedPriorityQueue.scala | 62 - .../scala/spark/util/ByteBufferInputStream.scala | 80 -- core/src/main/scala/spark/util/Clock.scala | 29 - .../main/scala/spark/util/CompletionIterator.scala | 42 - core/src/main/scala/spark/util/Distribution.scala | 82 -- core/src/main/scala/spark/util/IdGenerator.scala | 31 - core/src/main/scala/spark/util/IntParam.scala | 31 - core/src/main/scala/spark/util/MemoryParam.scala | 34 - .../main/scala/spark/util/MetadataCleaner.scala | 61 - core/src/main/scala/spark/util/MutablePair.scala | 36 - core/src/main/scala/spark/util/NextIterator.scala | 88 -- .../scala/spark/util/RateLimitedOutputStream.scala | 79 -- .../main/scala/spark/util/SerializableBuffer.scala | 54 - core/src/main/scala/spark/util/StatCounter.scala | 131 -- .../main/scala/spark/util/TimeStampedHashMap.scala | 121 -- .../main/scala/spark/util/TimeStampedHashSet.scala | 86 -- core/src/main/scala/spark/util/Vector.scala | 139 --- .../test/resources/test_metrics_config.properties | 2 +- .../test/resources/test_metrics_system.properties | 6 +- .../scala/org/apache/spark/AccumulatorSuite.scala | 143 +++ .../scala/org/apache/spark/BroadcastSuite.scala | 39 + .../scala/org/apache/spark/CheckpointSuite.scala | 392 ++++++ .../org/apache/spark/ClosureCleanerSuite.scala | 146 +++ .../scala/org/apache/spark/DistributedSuite.scala | 362 ++++++ .../test/scala/org/apache/spark/DriverSuite.scala | 54 + .../test/scala/org/apache/spark/FailureSuite.scala | 127 ++ .../scala/org/apache/spark/FileServerSuite.scala | 123 ++ .../test/scala/org/apache/spark/FileSuite.scala | 212 ++++ .../test/scala/org/apache/spark/JavaAPISuite.java | 865 +++++++++++++ .../org/apache/spark/KryoSerializerSuite.scala | 208 ++++ .../scala/org/apache/spark/LocalSparkContext.scala | 68 + .../org/apache/spark/MapOutputTrackerSuite.scala | 136 ++ .../org/apache/spark/PairRDDFunctionsSuite.scala | 299 +++++ .../apache/spark/PartitionPruningRDDSuite.scala | 28 + .../scala/org/apache/spark/PartitioningSuite.scala | 150 +++ .../scala/org/apache/spark/PipedRDDSuite.scala | 93 ++ .../src/test/scala/org/apache/spark/RDDSuite.scala | 389 ++++++ .../org/apache/spark/SharedSparkContext.scala | 42 + .../scala/org/apache/spark/ShuffleNettySuite.scala | 34 + .../test/scala/org/apache/spark/ShuffleSuite.scala | 210 ++++ .../org/apache/spark/SizeEstimatorSuite.scala | 164 +++ .../test/scala/org/apache/spark/SortingSuite.scala | 123 ++ .../org/apache/spark/SparkContextInfoSuite.scala | 60 + .../scala/org/apache/spark/ThreadingSuite.scala | 152 +++ .../scala/org/apache/spark/UnpersistSuite.scala | 47 + .../test/scala/org/apache/spark/UtilsSuite.scala | 139 +++ .../org/apache/spark/ZippedPartitionsSuite.scala | 50 + .../apache/spark/io/CompressionCodecSuite.scala | 62 + .../apache/spark/metrics/MetricsConfigSuite.scala | 89 ++ .../apache/spark/metrics/MetricsSystemSuite.scala | 54 + .../scala/org/apache/spark/rdd/JdbcRDDSuite.scala | 73 ++ .../spark/rdd/ParallelCollectionSplitSuite.scala | 212 ++++ .../apache/spark/scheduler/DAGSchedulerSuite.scala | 421 +++++++ .../apache/spark/scheduler/JobLoggerSuite.scala | 121 ++ .../spark/scheduler/SparkListenerSuite.scala | 102 ++ .../apache/spark/scheduler/TaskContextSuite.scala | 49 + .../scheduler/cluster/ClusterSchedulerSuite.scala | 266 ++++ .../cluster/ClusterTaskSetManagerSuite.scala | 273 ++++ .../apache/spark/scheduler/cluster/FakeTask.scala | 26 + .../scheduler/local/LocalSchedulerSuite.scala | 223 ++++ .../apache/spark/storage/BlockManagerSuite.scala | 666 ++++++++++ .../test/scala/org/apache/spark/ui/UISuite.scala | 47 + .../org/apache/spark/util/DistributionSuite.scala | 42 + .../scala/org/apache/spark/util/FakeClock.scala | 26 + .../org/apache/spark/util/NextIteratorSuite.scala | 85 ++ .../spark/util/RateLimitedOutputStreamSuite.scala | 40 + core/src/test/scala/spark/AccumulatorSuite.scala | 143 --- core/src/test/scala/spark/BroadcastSuite.scala | 39 - core/src/test/scala/spark/CheckpointSuite.scala | 392 ------ .../src/test/scala/spark/ClosureCleanerSuite.scala | 146 --- core/src/test/scala/spark/DistributedSuite.scala | 362 ------ core/src/test/scala/spark/DriverSuite.scala | 54 - core/src/test/scala/spark/FailureSuite.scala | 127 -- core/src/test/scala/spark/FileServerSuite.scala | 123 -- core/src/test/scala/spark/FileSuite.scala | 212 ---- core/src/test/scala/spark/JavaAPISuite.java | 865 ------------- .../src/test/scala/spark/KryoSerializerSuite.scala | 208 ---- core/src/test/scala/spark/LocalSparkContext.scala | 68 - .../test/scala/spark/MapOutputTrackerSuite.scala | 136 -- .../test/scala/spark/PairRDDFunctionsSuite.scala | 299 ----- .../scala/spark/PartitionPruningRDDSuite.scala | 28 - core/src/test/scala/spark/PartitioningSuite.scala | 150 --- core/src/test/scala/spark/PipedRDDSuite.scala | 93 -- core/src/test/scala/spark/RDDSuite.scala | 389 ------ core/src/test/scala/spark/SharedSparkContext.scala | 42 - core/src/test/scala/spark/ShuffleNettySuite.scala | 34 - core/src/test/scala/spark/ShuffleSuite.scala | 210 ---- core/src/test/scala/spark/SizeEstimatorSuite.scala | 164 --- core/src/test/scala/spark/SortingSuite.scala | 123 -- .../test/scala/spark/SparkContextInfoSuite.scala | 60 - core/src/test/scala/spark/ThreadingSuite.scala | 152 --- core/src/test/scala/spark/UnpersistSuite.scala | 47 - core/src/test/scala/spark/UtilsSuite.scala | 139 --- .../test/scala/spark/ZippedPartitionsSuite.scala | 50 - .../scala/spark/io/CompressionCodecSuite.scala | 62 - .../scala/spark/metrics/MetricsConfigSuite.scala | 89 -- .../scala/spark/metrics/MetricsSystemSuite.scala | 53 - core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 73 -- .../spark/rdd/ParallelCollectionSplitSuite.scala | 212 ---- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 421 ------- .../scala/spark/scheduler/JobLoggerSuite.scala | 121 -- .../scala/spark/scheduler/SparkListenerSuite.scala | 102 -- .../scala/spark/scheduler/TaskContextSuite.scala | 49 - .../scheduler/cluster/ClusterSchedulerSuite.scala | 266 ---- .../cluster/ClusterTaskSetManagerSuite.scala | 273 ---- .../scala/spark/scheduler/cluster/FakeTask.scala | 26 - .../scheduler/local/LocalSchedulerSuite.scala | 223 ---- .../scala/spark/storage/BlockManagerSuite.scala | 665 ---------- core/src/test/scala/spark/ui/UISuite.scala | 47 - .../test/scala/spark/util/DistributionSuite.scala | 42 - core/src/test/scala/spark/util/FakeClock.scala | 26 - .../test/scala/spark/util/NextIteratorSuite.scala | 85 -- .../spark/util/RateLimitedOutputStreamSuite.scala | 40 - docs/_layouts/global.html | 2 +- examples/pom.xml | 14 +- .../java/org/apache/spark/examples/JavaHdfsLR.java | 140 +++ .../java/org/apache/spark/examples/JavaKMeans.java | 131 ++ .../org/apache/spark/examples/JavaLogQuery.java | 131 ++ .../org/apache/spark/examples/JavaPageRank.java | 115 ++ .../org/apache/spark/examples/JavaSparkPi.java | 65 + .../java/org/apache/spark/examples/JavaTC.java | 97 ++ .../org/apache/spark/examples/JavaWordCount.java | 66 + .../org/apache/spark/mllib/examples/JavaALS.java | 87 ++ .../apache/spark/mllib/examples/JavaKMeans.java | 81 ++ .../org/apache/spark/mllib/examples/JavaLR.java | 85 ++ .../streaming/examples/JavaFlumeEventCount.java | 68 + .../streaming/examples/JavaNetworkWordCount.java | 79 ++ .../spark/streaming/examples/JavaQueueStream.java | 80 ++ .../src/main/java/spark/examples/JavaHdfsLR.java | 140 --- .../src/main/java/spark/examples/JavaKMeans.java | 131 -- .../src/main/java/spark/examples/JavaLogQuery.java | 131 -- .../src/main/java/spark/examples/JavaPageRank.java | 115 -- .../src/main/java/spark/examples/JavaSparkPi.java | 65 - examples/src/main/java/spark/examples/JavaTC.java | 97 -- .../main/java/spark/examples/JavaWordCount.java | 66 - .../main/java/spark/mllib/examples/JavaALS.java | 87 -- .../main/java/spark/mllib/examples/JavaKMeans.java | 81 -- .../src/main/java/spark/mllib/examples/JavaLR.java | 85 -- .../streaming/examples/JavaFlumeEventCount.java | 68 - .../streaming/examples/JavaNetworkWordCount.java | 79 -- .../spark/streaming/examples/JavaQueueStream.java | 80 -- .../org/apache/spark/examples/BroadcastTest.scala | 50 + .../org/apache/spark/examples/CassandraTest.scala | 213 ++++ .../spark/examples/ExceptionHandlingTest.scala | 38 + .../org/apache/spark/examples/GroupByTest.scala | 57 + .../org/apache/spark/examples/HBaseTest.scala | 52 + .../scala/org/apache/spark/examples/HdfsTest.scala | 37 + .../scala/org/apache/spark/examples/LocalALS.scala | 140 +++ .../org/apache/spark/examples/LocalFileLR.scala | 55 + .../org/apache/spark/examples/LocalKMeans.scala | 99 ++ .../scala/org/apache/spark/examples/LocalLR.scala | 63 + .../scala/org/apache/spark/examples/LocalPi.scala | 34 + .../scala/org/apache/spark/examples/LogQuery.scala | 85 ++ .../apache/spark/examples/MultiBroadcastTest.scala | 53 + .../spark/examples/SimpleSkewedGroupByTest.scala | 71 ++ .../apache/spark/examples/SkewedGroupByTest.scala | 61 + .../scala/org/apache/spark/examples/SparkALS.scala | 143 +++ .../org/apache/spark/examples/SparkHdfsLR.scala | 78 ++ .../org/apache/spark/examples/SparkKMeans.scala | 91 ++ .../scala/org/apache/spark/examples/SparkLR.scala | 71 ++ .../org/apache/spark/examples/SparkPageRank.scala | 46 + .../scala/org/apache/spark/examples/SparkPi.scala | 43 + .../scala/org/apache/spark/examples/SparkTC.scala | 75 ++ .../spark/examples/bagel/PageRankUtils.scala | 123 ++ .../spark/examples/bagel/WikipediaPageRank.scala | 101 ++ .../bagel/WikipediaPageRankStandalone.scala | 223 ++++ .../spark/streaming/examples/ActorWordCount.scala | 175 +++ .../spark/streaming/examples/FlumeEventCount.scala | 61 + .../spark/streaming/examples/HdfsWordCount.scala | 54 + .../spark/streaming/examples/KafkaWordCount.scala | 98 ++ .../streaming/examples/NetworkWordCount.scala | 54 + .../spark/streaming/examples/QueueStream.scala | 57 + .../spark/streaming/examples/RawNetworkGrep.scala | 64 + .../examples/StatefulNetworkWordCount.scala | 67 + .../streaming/examples/TwitterAlgebirdCMS.scala | 110 ++ .../streaming/examples/TwitterAlgebirdHLL.scala | 88 ++ .../streaming/examples/TwitterPopularTags.scala | 70 ++ .../spark/streaming/examples/ZeroMQWordCount.scala | 91 ++ .../examples/clickstream/PageViewGenerator.scala | 102 ++ .../examples/clickstream/PageViewStream.scala | 101 ++ .../main/scala/spark/examples/BroadcastTest.scala | 50 - .../main/scala/spark/examples/CassandraTest.scala | 213 ---- .../spark/examples/ExceptionHandlingTest.scala | 38 - .../main/scala/spark/examples/GroupByTest.scala | 57 - .../src/main/scala/spark/examples/HBaseTest.scala | 52 - .../src/main/scala/spark/examples/HdfsTest.scala | 37 - .../src/main/scala/spark/examples/LocalALS.scala | 140 --- .../main/scala/spark/examples/LocalFileLR.scala | 55 - .../main/scala/spark/examples/LocalKMeans.scala | 99 -- .../src/main/scala/spark/examples/LocalLR.scala | 63 - .../src/main/scala/spark/examples/LocalPi.scala | 34 - .../src/main/scala/spark/examples/LogQuery.scala | 85 -- .../scala/spark/examples/MultiBroadcastTest.scala | 53 - .../spark/examples/SimpleSkewedGroupByTest.scala | 71 -- .../scala/spark/examples/SkewedGroupByTest.scala | 61 - .../src/main/scala/spark/examples/SparkALS.scala | 143 --- .../main/scala/spark/examples/SparkHdfsLR.scala | 78 -- .../main/scala/spark/examples/SparkKMeans.scala | 91 -- .../src/main/scala/spark/examples/SparkLR.scala | 71 -- .../main/scala/spark/examples/SparkPageRank.scala | 46 - .../src/main/scala/spark/examples/SparkPi.scala | 43 - .../src/main/scala/spark/examples/SparkTC.scala | 75 -- .../scala/spark/examples/bagel/PageRankUtils.scala | 123 -- .../spark/examples/bagel/WikipediaPageRank.scala | 101 -- .../bagel/WikipediaPageRankStandalone.scala | 223 ---- .../spark/streaming/examples/ActorWordCount.scala | 175 --- .../spark/streaming/examples/FlumeEventCount.scala | 61 - .../spark/streaming/examples/HdfsWordCount.scala | 54 - .../spark/streaming/examples/KafkaWordCount.scala | 98 -- .../streaming/examples/NetworkWordCount.scala | 54 - .../spark/streaming/examples/QueueStream.scala | 57 - .../spark/streaming/examples/RawNetworkGrep.scala | 64 - .../examples/StatefulNetworkWordCount.scala | 67 - .../streaming/examples/TwitterAlgebirdCMS.scala | 110 -- .../streaming/examples/TwitterAlgebirdHLL.scala | 88 -- .../streaming/examples/TwitterPopularTags.scala | 70 -- .../spark/streaming/examples/ZeroMQWordCount.scala | 91 -- .../examples/clickstream/PageViewGenerator.scala | 102 -- .../examples/clickstream/PageViewStream.scala | 101 -- mllib/pom.xml | 6 +- .../mllib/classification/ClassificationModel.scala | 21 + .../mllib/classification/LogisticRegression.scala | 188 +++ .../apache/spark/mllib/classification/SVM.scala | 187 +++ .../org/apache/spark/mllib/clustering/KMeans.scala | 335 +++++ .../spark/mllib/clustering/KMeansModel.scala | 44 + .../spark/mllib/clustering/LocalKMeans.scala | 105 ++ .../apache/spark/mllib/optimization/Gradient.scala | 98 ++ .../spark/mllib/optimization/GradientDescent.scala | 166 +++ .../spark/mllib/optimization/Optimizer.scala | 29 + .../apache/spark/mllib/optimization/Updater.scala | 99 ++ .../apache/spark/mllib/recommendation/ALS.scala | 453 +++++++ .../recommendation/MatrixFactorizationModel.scala | 49 + .../regression/GeneralizedLinearAlgorithm.scala | 159 +++ .../spark/mllib/regression/LabeledPoint.scala | 26 + .../org/apache/spark/mllib/regression/Lasso.scala | 210 ++++ .../spark/mllib/regression/LinearRegression.scala | 167 +++ .../spark/mllib/regression/RegressionModel.scala | 38 + .../spark/mllib/regression/RidgeRegression.scala | 213 ++++ .../apache/spark/mllib/util/DataValidators.scala | 42 + .../spark/mllib/util/KMeansDataGenerator.scala | 84 ++ .../spark/mllib/util/LinearDataGenerator.scala | 132 ++ .../util/LogisticRegressionDataGenerator.scala | 81 ++ .../apache/spark/mllib/util/MFDataGenerator.scala | 113 ++ .../org/apache/spark/mllib/util/MLUtils.scala | 122 ++ .../apache/spark/mllib/util/SVMDataGenerator.scala | 50 + .../mllib/classification/ClassificationModel.scala | 21 - .../mllib/classification/LogisticRegression.scala | 188 --- .../scala/spark/mllib/classification/SVM.scala | 187 --- .../main/scala/spark/mllib/clustering/KMeans.scala | 335 ----- .../scala/spark/mllib/clustering/KMeansModel.scala | 44 - .../scala/spark/mllib/clustering/LocalKMeans.scala | 105 -- .../scala/spark/mllib/optimization/Gradient.scala | 98 -- .../spark/mllib/optimization/GradientDescent.scala | 166 --- .../scala/spark/mllib/optimization/Optimizer.scala | 29 - .../scala/spark/mllib/optimization/Updater.scala | 99 -- .../scala/spark/mllib/recommendation/ALS.scala | 453 ------- .../recommendation/MatrixFactorizationModel.scala | 49 - .../regression/GeneralizedLinearAlgorithm.scala | 159 --- .../spark/mllib/regression/LabeledPoint.scala | 26 - .../main/scala/spark/mllib/regression/Lasso.scala | 210 ---- .../spark/mllib/regression/LinearRegression.scala | 167 --- .../spark/mllib/regression/RegressionModel.scala | 38 - .../spark/mllib/regression/RidgeRegression.scala | 213 ---- .../scala/spark/mllib/util/DataValidators.scala | 42 - .../spark/mllib/util/KMeansDataGenerator.scala | 84 -- .../spark/mllib/util/LinearDataGenerator.scala | 132 -- .../util/LogisticRegressionDataGenerator.scala | 81 -- .../scala/spark/mllib/util/MFDataGenerator.scala | 113 -- .../src/main/scala/spark/mllib/util/MLUtils.scala | 122 -- .../scala/spark/mllib/util/SVMDataGenerator.scala | 50 - .../JavaLogisticRegressionSuite.java | 98 ++ .../spark/mllib/classification/JavaSVMSuite.java | 98 ++ .../spark/mllib/clustering/JavaKMeansSuite.java | 115 ++ .../spark/mllib/recommendation/JavaALSSuite.java | 110 ++ .../spark/mllib/regression/JavaLassoSuite.java | 97 ++ .../regression/JavaLinearRegressionSuite.java | 94 ++ .../mllib/regression/JavaRidgeRegressionSuite.java | 110 ++ .../JavaLogisticRegressionSuite.java | 98 -- .../spark/mllib/classification/JavaSVMSuite.java | 98 -- .../spark/mllib/clustering/JavaKMeansSuite.java | 115 -- .../spark/mllib/recommendation/JavaALSSuite.java | 110 -- .../spark/mllib/regression/JavaLassoSuite.java | 97 -- .../regression/JavaLinearRegressionSuite.java | 94 -- .../mllib/regression/JavaRidgeRegressionSuite.java | 110 -- .../classification/LogisticRegressionSuite.scala | 150 +++ .../spark/mllib/classification/SVMSuite.scala | 169 +++ .../spark/mllib/clustering/KMeansSuite.scala | 173 +++ .../spark/mllib/recommendation/ALSSuite.scala | 125 ++ .../apache/spark/mllib/regression/LassoSuite.scala | 121 ++ .../mllib/regression/LinearRegressionSuite.scala | 72 ++ .../mllib/regression/RidgeRegressionSuite.scala | 90 ++ .../classification/LogisticRegressionSuite.scala | 150 --- .../spark/mllib/classification/SVMSuite.scala | 169 --- .../scala/spark/mllib/clustering/KMeansSuite.scala | 173 --- .../spark/mllib/recommendation/ALSSuite.scala | 125 -- .../scala/spark/mllib/regression/LassoSuite.scala | 121 -- .../mllib/regression/LinearRegressionSuite.scala | 72 -- .../mllib/regression/RidgeRegressionSuite.scala | 90 -- pom.xml | 14 +- project/SparkBuild.scala | 12 +- python/pyspark/context.py | 4 +- python/pyspark/files.py | 2 +- python/pyspark/java_gateway.py | 4 +- repl-bin/pom.xml | 12 +- repl/pom.xml | 12 +- .../apache/spark/repl/ExecutorClassLoader.scala | 124 ++ .../main/scala/org/apache/spark/repl/Main.scala | 33 + .../scala/org/apache/spark/repl/SparkHelper.scala | 5 + .../scala/org/apache/spark/repl/SparkILoop.scala | 1008 +++++++++++++++ .../scala/org/apache/spark/repl/SparkIMain.scala | 1160 +++++++++++++++++ .../org/apache/spark/repl/SparkISettings.scala | 63 + .../scala/org/apache/spark/repl/SparkImports.scala | 214 ++++ .../apache/spark/repl/SparkJLineCompletion.scala | 379 ++++++ .../org/apache/spark/repl/SparkJLineReader.scala | 79 ++ .../apache/spark/repl/SparkMemberHandlers.scala | 207 ++++ .../scala/spark/repl/ExecutorClassLoader.scala | 124 -- repl/src/main/scala/spark/repl/Main.scala | 33 - repl/src/main/scala/spark/repl/SparkHelper.scala | 5 - repl/src/main/scala/spark/repl/SparkILoop.scala | 1008 --------------- repl/src/main/scala/spark/repl/SparkIMain.scala | 1160 ----------------- .../src/main/scala/spark/repl/SparkISettings.scala | 63 - repl/src/main/scala/spark/repl/SparkImports.scala | 214 ---- .../scala/spark/repl/SparkJLineCompletion.scala | 379 ------ .../main/scala/spark/repl/SparkJLineReader.scala | 79 -- .../scala/spark/repl/SparkMemberHandlers.scala | 207 ---- .../scala/org/apache/spark/repl/ReplSuite.scala | 207 ++++ repl/src/test/scala/spark/repl/ReplSuite.scala | 207 ---- spark-executor | 2 +- spark-shell | 2 +- spark-shell.cmd | 2 +- streaming/pom.xml | 6 +- .../org/apache/spark/streaming/Checkpoint.scala | 190 +++ .../scala/org/apache/spark/streaming/DStream.scala | 702 +++++++++++ .../spark/streaming/DStreamCheckpointData.scala | 110 ++ .../org/apache/spark/streaming/DStreamGraph.scala | 167 +++ .../org/apache/spark/streaming/Duration.scala | 83 ++ .../org/apache/spark/streaming/Interval.scala | 59 + .../scala/org/apache/spark/streaming/Job.scala | 41 + .../org/apache/spark/streaming/JobManager.scala | 88 ++ .../spark/streaming/NetworkInputTracker.scala | 173 +++ .../spark/streaming/PairDStreamFunctions.scala | 534 ++++++++ .../org/apache/spark/streaming/Scheduler.scala | 131 ++ .../apache/spark/streaming/StreamingContext.scala | 563 +++++++++ .../scala/org/apache/spark/streaming/Time.scala | 72 ++ .../spark/streaming/api/java/JavaDStream.scala | 102 ++ .../spark/streaming/api/java/JavaDStreamLike.scala | 316 +++++ .../spark/streaming/api/java/JavaPairDStream.scala | 613 +++++++++ .../streaming/api/java/JavaStreamingContext.scala | 614 +++++++++ .../spark/streaming/dstream/CoGroupedDStream.scala | 57 + .../streaming/dstream/ConstantInputDStream.scala | 36 + .../spark/streaming/dstream/FileInputDStream.scala | 199 +++ .../spark/streaming/dstream/FilteredDStream.scala | 38 + .../streaming/dstream/FlatMapValuedDStream.scala | 37 + .../streaming/dstream/FlatMappedDStream.scala | 37 + .../streaming/dstream/FlumeInputDStream.scala | 154 +++ .../spark/streaming/dstream/ForEachDStream.scala | 45 + .../spark/streaming/dstream/GlommedDStream.scala | 34 + .../spark/streaming/dstream/InputDStream.scala | 70 ++ .../streaming/dstream/KafkaInputDStream.scala | 141 +++ .../streaming/dstream/MapPartitionedDStream.scala | 38 + .../spark/streaming/dstream/MapValuedDStream.scala | 38 + .../spark/streaming/dstream/MappedDStream.scala | 37 + .../streaming/dstream/NetworkInputDStream.scala | 272 ++++ .../streaming/dstream/PluggableInputDStream.scala | 30 + .../streaming/dstream/QueueInputDStream.scala | 59 + .../spark/streaming/dstream/RawInputDStream.scala | 108 ++ .../streaming/dstream/ReducedWindowedDStream.scala | 174 +++ .../spark/streaming/dstream/ShuffledDStream.scala | 44 + .../streaming/dstream/SocketInputDStream.scala | 94 ++ .../spark/streaming/dstream/StateDStream.scala | 109 ++ .../streaming/dstream/TransformedDStream.scala | 36 + .../streaming/dstream/TwitterInputDStream.scala | 99 ++ .../spark/streaming/dstream/UnionDStream.scala | 57 + .../spark/streaming/dstream/WindowedDStream.scala | 57 + .../spark/streaming/receivers/ActorReceiver.scala | 175 +++ .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 + .../org/apache/spark/streaming/util/Clock.scala | 101 ++ .../spark/streaming/util/MasterFailureTest.scala | 414 +++++++ .../spark/streaming/util/RawTextHelper.scala | 115 ++ .../spark/streaming/util/RawTextSender.scala | 77 ++ .../spark/streaming/util/RecurringTimer.scala | 94 ++ .../main/scala/spark/streaming/Checkpoint.scala | 190 --- .../src/main/scala/spark/streaming/DStream.scala | 700 ----------- .../spark/streaming/DStreamCheckpointData.scala | 110 -- .../main/scala/spark/streaming/DStreamGraph.scala | 167 --- .../src/main/scala/spark/streaming/Duration.scala | 83 -- .../src/main/scala/spark/streaming/Interval.scala | 59 - streaming/src/main/scala/spark/streaming/Job.scala | 41 - .../main/scala/spark/streaming/JobManager.scala | 88 -- .../spark/streaming/NetworkInputTracker.scala | 173 --- .../spark/streaming/PairDStreamFunctions.scala | 534 -------- .../src/main/scala/spark/streaming/Scheduler.scala | 130 -- .../scala/spark/streaming/StreamingContext.scala | 563 --------- .../src/main/scala/spark/streaming/Time.scala | 72 -- .../spark/streaming/api/java/JavaDStream.scala | 102 -- .../spark/streaming/api/java/JavaDStreamLike.scala | 316 ----- .../spark/streaming/api/java/JavaPairDStream.scala | 613 --------- .../streaming/api/java/JavaStreamingContext.scala | 613 --------- .../spark/streaming/dstream/CoGroupedDStream.scala | 57 - .../streaming/dstream/ConstantInputDStream.scala | 36 - .../spark/streaming/dstream/FileInputDStream.scala | 199 --- .../spark/streaming/dstream/FilteredDStream.scala | 38 - .../streaming/dstream/FlatMapValuedDStream.scala | 37 - .../streaming/dstream/FlatMappedDStream.scala | 37 - .../streaming/dstream/FlumeInputDStream.scala | 154 --- .../spark/streaming/dstream/ForEachDStream.scala | 45 - .../spark/streaming/dstream/GlommedDStream.scala | 34 - .../spark/streaming/dstream/InputDStream.scala | 70 -- .../streaming/dstream/KafkaInputDStream.scala | 141 --- .../streaming/dstream/MapPartitionedDStream.scala | 38 - .../spark/streaming/dstream/MapValuedDStream.scala | 38 - .../spark/streaming/dstream/MappedDStream.scala | 37 - .../streaming/dstream/NetworkInputDStream.scala | 272 ---- .../streaming/dstream/PluggableInputDStream.scala | 30 - .../streaming/dstream/QueueInputDStream.scala | 59 - .../spark/streaming/dstream/RawInputDStream.scala | 108 -- .../streaming/dstream/ReducedWindowedDStream.scala | 174 --- .../spark/streaming/dstream/ShuffledDStream.scala | 44 - .../streaming/dstream/SocketInputDStream.scala | 94 -- .../spark/streaming/dstream/StateDStream.scala | 109 -- .../streaming/dstream/TransformedDStream.scala | 36 - .../streaming/dstream/TwitterInputDStream.scala | 99 -- .../spark/streaming/dstream/UnionDStream.scala | 57 - .../spark/streaming/dstream/WindowedDStream.scala | 57 - .../spark/streaming/receivers/ActorReceiver.scala | 175 --- .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 - .../main/scala/spark/streaming/util/Clock.scala | 101 -- .../spark/streaming/util/MasterFailureTest.scala | 414 ------- .../scala/spark/streaming/util/RawTextHelper.scala | 115 -- .../scala/spark/streaming/util/RawTextSender.scala | 77 -- .../spark/streaming/util/RecurringTimer.scala | 94 -- .../org/apache/spark/streaming/JavaAPISuite.java | 1304 ++++++++++++++++++++ .../org/apache/spark/streaming/JavaTestUtils.scala | 85 ++ .../test/java/spark/streaming/JavaAPISuite.java | 1304 -------------------- .../test/java/spark/streaming/JavaTestUtils.scala | 84 -- .../spark/streaming/BasicOperationsSuite.scala | 322 +++++ .../apache/spark/streaming/CheckpointSuite.scala | 372 ++++++ .../org/apache/spark/streaming/FailureSuite.scala | 57 + .../apache/spark/streaming/InputStreamsSuite.scala | 349 ++++++ .../org/apache/spark/streaming/TestSuiteBase.scala | 314 +++++ .../spark/streaming/WindowOperationsSuite.scala | 340 +++++ .../spark/streaming/BasicOperationsSuite.scala | 322 ----- .../scala/spark/streaming/CheckpointSuite.scala | 372 ------ .../test/scala/spark/streaming/FailureSuite.scala | 57 - .../scala/spark/streaming/InputStreamsSuite.scala | 349 ------ .../test/scala/spark/streaming/TestSuiteBase.scala | 314 ----- .../spark/streaming/WindowOperationsSuite.scala | 340 ----- tools/pom.xml | 8 +- .../spark/tools/JavaAPICompletenessChecker.scala | 360 ++++++ .../spark/tools/JavaAPICompletenessChecker.scala | 360 ------ yarn/pom.xml | 6 +- .../spark/deploy/yarn/ApplicationMaster.scala | 371 ++++++ .../deploy/yarn/ApplicationMasterArguments.scala | 94 ++ .../org/apache/spark/deploy/yarn/Client.scala | 336 +++++ .../apache/spark/deploy/yarn/ClientArguments.scala | 116 ++ .../apache/spark/deploy/yarn/WorkerRunnable.scala | 224 ++++ .../spark/deploy/yarn/YarnAllocationHandler.scala | 564 +++++++++ .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 46 + .../scheduler/cluster/YarnClusterScheduler.scala | 52 + .../spark/deploy/yarn/ApplicationMaster.scala | 371 ------ .../deploy/yarn/ApplicationMasterArguments.scala | 94 -- yarn/src/main/scala/spark/deploy/yarn/Client.scala | 336 ----- .../scala/spark/deploy/yarn/ClientArguments.scala | 116 -- .../scala/spark/deploy/yarn/WorkerRunnable.scala | 224 ---- .../spark/deploy/yarn/YarnAllocationHandler.scala | 564 --------- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 46 - .../scheduler/cluster/YarnClusterScheduler.scala | 52 - 1015 files changed, 70352 insertions(+), 70341 deletions(-) create mode 100644 bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala delete mode 100644 bagel/src/main/scala/spark/bagel/Bagel.scala delete mode 100644 bagel/src/test/scala/bagel/BagelSuite.scala create mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClient.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java create mode 100755 core/src/main/java/org/apache/spark/network/netty/PathResolver.java delete mode 100644 core/src/main/java/spark/network/netty/FileClient.java delete mode 100644 core/src/main/java/spark/network/netty/FileClientChannelInitializer.java delete mode 100644 core/src/main/java/spark/network/netty/FileClientHandler.java delete mode 100644 core/src/main/java/spark/network/netty/FileServer.java delete mode 100644 core/src/main/java/spark/network/netty/FileServerChannelInitializer.java delete mode 100644 core/src/main/java/spark/network/netty/FileServerHandler.java delete mode 100755 core/src/main/java/spark/network/netty/PathResolver.java create mode 100755 core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/sorttable.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/spark_logo.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/webui.css delete mode 100755 core/src/main/resources/spark/ui/static/bootstrap.min.css delete mode 100644 core/src/main/resources/spark/ui/static/sorttable.js delete mode 100644 core/src/main/resources/spark/ui/static/spark-logo-77x50px-hd.png delete mode 100644 core/src/main/resources/spark/ui/static/spark_logo.png delete mode 100644 core/src/main/resources/spark/ui/static/webui.css create mode 100644 core/src/main/scala/org/apache/spark/Accumulators.scala create mode 100644 core/src/main/scala/org/apache/spark/Aggregator.scala create mode 100644 core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/CacheManager.scala create mode 100644 core/src/main/scala/org/apache/spark/ClosureCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/Dependency.scala create mode 100644 core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/FetchFailedException.scala create mode 100644 core/src/main/scala/org/apache/spark/HttpFileServer.scala create mode 100644 core/src/main/scala/org/apache/spark/HttpServer.scala create mode 100644 core/src/main/scala/org/apache/spark/JavaSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/KryoSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/Logging.scala create mode 100644 core/src/main/scala/org/apache/spark/MapOutputTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/PairRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/Partition.scala create mode 100644 core/src/main/scala/org/apache/spark/Partitioner.scala create mode 100644 core/src/main/scala/org/apache/spark/RDD.scala create mode 100644 core/src/main/scala/org/apache/spark/RDDCheckpointData.scala create mode 100644 core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/SerializableWritable.scala create mode 100644 core/src/main/scala/org/apache/spark/ShuffleFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/SizeEstimator.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkContext.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkEnv.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkException.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkFiles.java create mode 100644 core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskEndReason.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskState.scala create mode 100644 core/src/main/scala/org/apache/spark/Utils.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/StorageLevels.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/Function.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/Function2.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/Command.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/WebUI.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/Client.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/Master.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/Executor.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/io/CompressionCodec.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/Source.scala create mode 100644 core/src/main/scala/org/apache/spark/network/BufferMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/network/Connection.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManager.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/Message.scala create mode 100644 core/src/main/scala/org/apache/spark/network/MessageChunk.scala create mode 100644 core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ReceiverTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/SenderTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala create mode 100644 core/src/main/scala/org/apache/spark/package.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/PartialResult.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobResult.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/Stage.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/Task.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/Serializer.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockException.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/DiskStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/MemoryStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/PutResult.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/StorageLevel.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/StorageUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/Page.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/SparkUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/UIUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala create mode 100644 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Clock.scala create mode 100644 core/src/main/scala/org/apache/spark/util/CompletionIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Distribution.scala create mode 100644 core/src/main/scala/org/apache/spark/util/IdGenerator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/IntParam.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MemoryParam.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MutablePair.scala create mode 100644 core/src/main/scala/org/apache/spark/util/NextIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/util/StatCounter.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Vector.scala delete mode 100644 core/src/main/scala/spark/Accumulators.scala delete mode 100644 core/src/main/scala/spark/Aggregator.scala delete mode 100644 core/src/main/scala/spark/BlockStoreShuffleFetcher.scala delete mode 100644 core/src/main/scala/spark/CacheManager.scala delete mode 100644 core/src/main/scala/spark/ClosureCleaner.scala delete mode 100644 core/src/main/scala/spark/Dependency.scala delete mode 100644 core/src/main/scala/spark/DoubleRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/FetchFailedException.scala delete mode 100644 core/src/main/scala/spark/HttpFileServer.scala delete mode 100644 core/src/main/scala/spark/HttpServer.scala delete mode 100644 core/src/main/scala/spark/JavaSerializer.scala delete mode 100644 core/src/main/scala/spark/KryoSerializer.scala delete mode 100644 core/src/main/scala/spark/Logging.scala delete mode 100644 core/src/main/scala/spark/MapOutputTracker.scala delete mode 100644 core/src/main/scala/spark/PairRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/Partition.scala delete mode 100644 core/src/main/scala/spark/Partitioner.scala delete mode 100644 core/src/main/scala/spark/RDD.scala delete mode 100644 core/src/main/scala/spark/RDDCheckpointData.scala delete mode 100644 core/src/main/scala/spark/SequenceFileRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/SerializableWritable.scala delete mode 100644 core/src/main/scala/spark/ShuffleFetcher.scala delete mode 100644 core/src/main/scala/spark/SizeEstimator.scala delete mode 100644 core/src/main/scala/spark/SparkContext.scala delete mode 100644 core/src/main/scala/spark/SparkEnv.scala delete mode 100644 core/src/main/scala/spark/SparkException.scala delete mode 100644 core/src/main/scala/spark/SparkFiles.java delete mode 100644 core/src/main/scala/spark/SparkHadoopWriter.scala delete mode 100644 core/src/main/scala/spark/TaskContext.scala delete mode 100644 core/src/main/scala/spark/TaskEndReason.scala delete mode 100644 core/src/main/scala/spark/TaskState.scala delete mode 100644 core/src/main/scala/spark/Utils.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaDoubleRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaPairRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaRDDLike.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaSparkContext.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java delete mode 100644 core/src/main/scala/spark/api/java/JavaUtils.scala delete mode 100644 core/src/main/scala/spark/api/java/StorageLevels.java delete mode 100644 core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/DoubleFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/FlatMapFunction.scala delete mode 100644 core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala delete mode 100644 core/src/main/scala/spark/api/java/function/Function.java delete mode 100644 core/src/main/scala/spark/api/java/function/Function2.java delete mode 100644 core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/PairFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/VoidFunction.scala delete mode 100644 core/src/main/scala/spark/api/java/function/WrappedFunction1.scala delete mode 100644 core/src/main/scala/spark/api/java/function/WrappedFunction2.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonRDD.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonWorkerFactory.scala delete mode 100644 core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/Broadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/BroadcastFactory.scala delete mode 100644 core/src/main/scala/spark/broadcast/HttpBroadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/MultiTracker.scala delete mode 100644 core/src/main/scala/spark/broadcast/SourceInfo.scala delete mode 100644 core/src/main/scala/spark/broadcast/TreeBroadcast.scala delete mode 100644 core/src/main/scala/spark/deploy/ApplicationDescription.scala delete mode 100644 core/src/main/scala/spark/deploy/Command.scala delete mode 100644 core/src/main/scala/spark/deploy/DeployMessage.scala delete mode 100644 core/src/main/scala/spark/deploy/ExecutorState.scala delete mode 100644 core/src/main/scala/spark/deploy/JsonProtocol.scala delete mode 100644 core/src/main/scala/spark/deploy/LocalSparkCluster.scala delete mode 100644 core/src/main/scala/spark/deploy/SparkHadoopUtil.scala delete mode 100644 core/src/main/scala/spark/deploy/WebUI.scala delete mode 100644 core/src/main/scala/spark/deploy/client/Client.scala delete mode 100644 core/src/main/scala/spark/deploy/client/ClientListener.scala delete mode 100644 core/src/main/scala/spark/deploy/client/TestClient.scala delete mode 100644 core/src/main/scala/spark/deploy/client/TestExecutor.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationInfo.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationSource.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationState.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ExecutorInfo.scala delete mode 100644 core/src/main/scala/spark/deploy/master/Master.scala delete mode 100644 core/src/main/scala/spark/deploy/master/MasterArguments.scala delete mode 100644 core/src/main/scala/spark/deploy/master/MasterSource.scala delete mode 100644 core/src/main/scala/spark/deploy/master/WorkerInfo.scala delete mode 100644 core/src/main/scala/spark/deploy/master/WorkerState.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/IndexPage.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/Worker.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/WorkerArguments.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/WorkerSource.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala delete mode 100644 core/src/main/scala/spark/executor/Executor.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorExitCode.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorSource.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala delete mode 100644 core/src/main/scala/spark/executor/MesosExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/TaskMetrics.scala delete mode 100644 core/src/main/scala/spark/io/CompressionCodec.scala delete mode 100644 core/src/main/scala/spark/metrics/MetricsConfig.scala delete mode 100644 core/src/main/scala/spark/metrics/MetricsSystem.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/ConsoleSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/CsvSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/JmxSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/MetricsServlet.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/Sink.scala delete mode 100644 core/src/main/scala/spark/metrics/source/JvmSource.scala delete mode 100644 core/src/main/scala/spark/metrics/source/Source.scala delete mode 100644 core/src/main/scala/spark/network/BufferMessage.scala delete mode 100644 core/src/main/scala/spark/network/Connection.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManager.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManagerId.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManagerTest.scala delete mode 100644 core/src/main/scala/spark/network/Message.scala delete mode 100644 core/src/main/scala/spark/network/MessageChunk.scala delete mode 100644 core/src/main/scala/spark/network/MessageChunkHeader.scala delete mode 100644 core/src/main/scala/spark/network/ReceiverTest.scala delete mode 100644 core/src/main/scala/spark/network/SenderTest.scala delete mode 100644 core/src/main/scala/spark/network/netty/FileHeader.scala delete mode 100644 core/src/main/scala/spark/network/netty/ShuffleCopier.scala delete mode 100644 core/src/main/scala/spark/network/netty/ShuffleSender.scala delete mode 100644 core/src/main/scala/spark/package.scala delete mode 100644 core/src/main/scala/spark/partial/ApproximateActionListener.scala delete mode 100644 core/src/main/scala/spark/partial/ApproximateEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/BoundedDouble.scala delete mode 100644 core/src/main/scala/spark/partial/CountEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/GroupedCountEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/GroupedSumEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/MeanEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/PartialResult.scala delete mode 100644 core/src/main/scala/spark/partial/StudentTCacher.scala delete mode 100644 core/src/main/scala/spark/partial/SumEvaluator.scala delete mode 100644 core/src/main/scala/spark/rdd/BlockRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CartesianRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CoGroupedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CoalescedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/EmptyRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FilteredRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FlatMappedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/GlommedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/HadoopRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/JdbcRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MapPartitionsRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MappedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MappedValuesRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/NewHadoopRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/PipedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SampledRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ShuffledRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SubtractedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/UnionRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ZippedRDD.scala delete mode 100644 core/src/main/scala/spark/scheduler/ActiveJob.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala delete mode 100644 core/src/main/scala/spark/scheduler/InputFormatInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobLogger.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobResult.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobWaiter.scala delete mode 100644 core/src/main/scala/spark/scheduler/MapStatus.scala delete mode 100644 core/src/main/scala/spark/scheduler/ResultTask.scala delete mode 100644 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala delete mode 100644 core/src/main/scala/spark/scheduler/SparkListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/SparkListenerBus.scala delete mode 100644 core/src/main/scala/spark/scheduler/SplitInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/Stage.scala delete mode 100644 core/src/main/scala/spark/scheduler/StageInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/Task.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskLocation.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskResult.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskSet.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/Pool.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/Schedulable.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala delete mode 100644 core/src/main/scala/spark/scheduler/local/LocalScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/serializer/Serializer.scala delete mode 100644 core/src/main/scala/spark/serializer/SerializerManager.scala delete mode 100644 core/src/main/scala/spark/storage/BlockException.scala delete mode 100644 core/src/main/scala/spark/storage/BlockFetchTracker.scala delete mode 100644 core/src/main/scala/spark/storage/BlockFetcherIterator.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManager.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMaster.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMasterActor.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMessages.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerSource.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerWorker.scala delete mode 100644 core/src/main/scala/spark/storage/BlockMessage.scala delete mode 100644 core/src/main/scala/spark/storage/BlockMessageArray.scala delete mode 100644 core/src/main/scala/spark/storage/BlockObjectWriter.scala delete mode 100644 core/src/main/scala/spark/storage/BlockStore.scala delete mode 100644 core/src/main/scala/spark/storage/DiskStore.scala delete mode 100644 core/src/main/scala/spark/storage/MemoryStore.scala delete mode 100644 core/src/main/scala/spark/storage/PutResult.scala delete mode 100644 core/src/main/scala/spark/storage/ShuffleBlockManager.scala delete mode 100644 core/src/main/scala/spark/storage/StorageLevel.scala delete mode 100644 core/src/main/scala/spark/storage/StorageUtils.scala delete mode 100644 core/src/main/scala/spark/storage/ThreadingTest.scala delete mode 100644 core/src/main/scala/spark/ui/JettyUtils.scala delete mode 100644 core/src/main/scala/spark/ui/Page.scala delete mode 100644 core/src/main/scala/spark/ui/SparkUI.scala delete mode 100644 core/src/main/scala/spark/ui/UIUtils.scala delete mode 100644 core/src/main/scala/spark/ui/UIWorkloadGenerator.scala delete mode 100644 core/src/main/scala/spark/ui/env/EnvironmentUI.scala delete mode 100644 core/src/main/scala/spark/ui/exec/ExecutorsUI.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/IndexPage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/JobProgressListener.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/JobProgressUI.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/PoolPage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/PoolTable.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/StagePage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/StageTable.scala delete mode 100644 core/src/main/scala/spark/ui/storage/BlockManagerUI.scala delete mode 100644 core/src/main/scala/spark/ui/storage/IndexPage.scala delete mode 100644 core/src/main/scala/spark/ui/storage/RDDPage.scala delete mode 100644 core/src/main/scala/spark/util/AkkaUtils.scala delete mode 100644 core/src/main/scala/spark/util/BoundedPriorityQueue.scala delete mode 100644 core/src/main/scala/spark/util/ByteBufferInputStream.scala delete mode 100644 core/src/main/scala/spark/util/Clock.scala delete mode 100644 core/src/main/scala/spark/util/CompletionIterator.scala delete mode 100644 core/src/main/scala/spark/util/Distribution.scala delete mode 100644 core/src/main/scala/spark/util/IdGenerator.scala delete mode 100644 core/src/main/scala/spark/util/IntParam.scala delete mode 100644 core/src/main/scala/spark/util/MemoryParam.scala delete mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala delete mode 100644 core/src/main/scala/spark/util/MutablePair.scala delete mode 100644 core/src/main/scala/spark/util/NextIterator.scala delete mode 100644 core/src/main/scala/spark/util/RateLimitedOutputStream.scala delete mode 100644 core/src/main/scala/spark/util/SerializableBuffer.scala delete mode 100644 core/src/main/scala/spark/util/StatCounter.scala delete mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala delete mode 100644 core/src/main/scala/spark/util/TimeStampedHashSet.scala delete mode 100644 core/src/main/scala/spark/util/Vector.scala create mode 100644 core/src/test/scala/org/apache/spark/AccumulatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/BroadcastSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/CheckpointSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/DistributedSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/DriverSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/FailureSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/FileServerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/FileSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/JavaAPISuite.java create mode 100644 core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/LocalSparkContext.scala create mode 100644 core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PartitioningSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PipedRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/RDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SharedSparkContext.scala create mode 100644 core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ShuffleSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SortingSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ThreadingSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/UnpersistSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/UtilsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/UISuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/DistributionSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/FakeClock.scala create mode 100644 core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala delete mode 100644 core/src/test/scala/spark/AccumulatorSuite.scala delete mode 100644 core/src/test/scala/spark/BroadcastSuite.scala delete mode 100644 core/src/test/scala/spark/CheckpointSuite.scala delete mode 100644 core/src/test/scala/spark/ClosureCleanerSuite.scala delete mode 100644 core/src/test/scala/spark/DistributedSuite.scala delete mode 100644 core/src/test/scala/spark/DriverSuite.scala delete mode 100644 core/src/test/scala/spark/FailureSuite.scala delete mode 100644 core/src/test/scala/spark/FileServerSuite.scala delete mode 100644 core/src/test/scala/spark/FileSuite.scala delete mode 100644 core/src/test/scala/spark/JavaAPISuite.java delete mode 100644 core/src/test/scala/spark/KryoSerializerSuite.scala delete mode 100644 core/src/test/scala/spark/LocalSparkContext.scala delete mode 100644 core/src/test/scala/spark/MapOutputTrackerSuite.scala delete mode 100644 core/src/test/scala/spark/PairRDDFunctionsSuite.scala delete mode 100644 core/src/test/scala/spark/PartitionPruningRDDSuite.scala delete mode 100644 core/src/test/scala/spark/PartitioningSuite.scala delete mode 100644 core/src/test/scala/spark/PipedRDDSuite.scala delete mode 100644 core/src/test/scala/spark/RDDSuite.scala delete mode 100644 core/src/test/scala/spark/SharedSparkContext.scala delete mode 100644 core/src/test/scala/spark/ShuffleNettySuite.scala delete mode 100644 core/src/test/scala/spark/ShuffleSuite.scala delete mode 100644 core/src/test/scala/spark/SizeEstimatorSuite.scala delete mode 100644 core/src/test/scala/spark/SortingSuite.scala delete mode 100644 core/src/test/scala/spark/SparkContextInfoSuite.scala delete mode 100644 core/src/test/scala/spark/ThreadingSuite.scala delete mode 100644 core/src/test/scala/spark/UnpersistSuite.scala delete mode 100644 core/src/test/scala/spark/UtilsSuite.scala delete mode 100644 core/src/test/scala/spark/ZippedPartitionsSuite.scala delete mode 100644 core/src/test/scala/spark/io/CompressionCodecSuite.scala delete mode 100644 core/src/test/scala/spark/metrics/MetricsConfigSuite.scala delete mode 100644 core/src/test/scala/spark/metrics/MetricsSystemSuite.scala delete mode 100644 core/src/test/scala/spark/rdd/JdbcRDDSuite.scala delete mode 100644 core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/JobLoggerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/SparkListenerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/FakeTask.scala delete mode 100644 core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/storage/BlockManagerSuite.scala delete mode 100644 core/src/test/scala/spark/ui/UISuite.scala delete mode 100644 core/src/test/scala/spark/util/DistributionSuite.scala delete mode 100644 core/src/test/scala/spark/util/FakeClock.scala delete mode 100644 core/src/test/scala/spark/util/NextIteratorSuite.scala delete mode 100644 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaKMeans.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaPageRank.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaTC.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaWordCount.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java delete mode 100644 examples/src/main/java/spark/examples/JavaHdfsLR.java delete mode 100644 examples/src/main/java/spark/examples/JavaKMeans.java delete mode 100644 examples/src/main/java/spark/examples/JavaLogQuery.java delete mode 100644 examples/src/main/java/spark/examples/JavaPageRank.java delete mode 100644 examples/src/main/java/spark/examples/JavaSparkPi.java delete mode 100644 examples/src/main/java/spark/examples/JavaTC.java delete mode 100644 examples/src/main/java/spark/examples/JavaWordCount.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaALS.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaKMeans.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaLR.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaQueueStream.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalALS.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalPi.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LogQuery.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkALS.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkPi.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkTC.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala delete mode 100644 examples/src/main/scala/spark/examples/BroadcastTest.scala delete mode 100644 examples/src/main/scala/spark/examples/CassandraTest.scala delete mode 100644 examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala delete mode 100644 examples/src/main/scala/spark/examples/GroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala delete mode 100644 examples/src/main/scala/spark/examples/HdfsTest.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalALS.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalFileLR.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalKMeans.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalLR.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalPi.scala delete mode 100644 examples/src/main/scala/spark/examples/LogQuery.scala delete mode 100644 examples/src/main/scala/spark/examples/MultiBroadcastTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SkewedGroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkALS.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkHdfsLR.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkKMeans.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkLR.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkPageRank.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkPi.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkTC.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/QueueStream.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/SVM.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/KMeans.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Gradient.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Updater.scala delete mode 100644 mllib/src/main/scala/spark/mllib/recommendation/ALS.scala delete mode 100644 mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/Lasso.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/DataValidators.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/MLUtils.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/Main.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkImports.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala delete mode 100644 repl/src/main/scala/spark/repl/ExecutorClassLoader.scala delete mode 100644 repl/src/main/scala/spark/repl/Main.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkHelper.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkILoop.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkIMain.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkISettings.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkImports.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkJLineCompletion.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkJLineReader.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkMemberHandlers.scala create mode 100644 repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala delete mode 100644 repl/src/test/scala/spark/repl/ReplSuite.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Duration.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Interval.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Job.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Time.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Duration.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Interval.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Job.scala delete mode 100644 streaming/src/main/scala/spark/streaming/JobManager.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Scheduler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Time.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/Clock.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala delete mode 100644 streaming/src/test/java/spark/streaming/JavaAPISuite.java delete mode 100644 streaming/src/test/java/spark/streaming/JavaTestUtils.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/CheckpointSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/FailureSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/TestSuiteBase.scala delete mode 100644 streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala create mode 100644 tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala delete mode 100644 tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/Client.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/YarnSparkHadoopUtil.scala delete mode 100644 yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala (limited to 'README.md') diff --git a/README.md b/README.md index 2ddfe862a2..c4170650f7 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Or, for the Python API, the Python shell (`./pyspark`). Spark also comes with several sample programs in the `examples` directory. To run one of them, use `./run-example `. For example: - ./run-example spark.examples.SparkLR local[2] + ./run-example org.apache.spark.examples.SparkLR local[2] will run the Logistic Regression example locally on 2 CPUs. diff --git a/assembly/pom.xml b/assembly/pom.xml index 74990b6361..dc63811b76 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-assembly Spark Project Assembly http://spark-project.org/ @@ -40,27 +40,27 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} - org.spark-project + org.apache.spark spark-mllib ${project.version} - org.spark-project + org.apache.spark spark-repl ${project.version} - org.spark-project + org.apache.spark spark-streaming ${project.version} @@ -121,7 +121,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 4543b52c93..47d3fa93d0 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -30,9 +30,9 @@ - ${project.parent.basedir}/core/src/main/resources/spark/ui/static/ + ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/spark/ui/static + /ui-resources/org/apache/spark/ui/static **/* @@ -63,10 +63,10 @@ - org.spark-project:*:jar + org.apache.spark:*:jar - org.spark-project:spark-assembly:jar + org.apache.spark:spark-assembly:jar @@ -77,7 +77,7 @@ false org.apache.hadoop:*:jar - org.spark-project:*:jar + org.apache.spark:*:jar diff --git a/bagel/pom.xml b/bagel/pom.xml index cbcf8d1239..9340991377 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-bagel jar Spark Project Bagel @@ -33,7 +33,7 @@ - org.spark-project + org.apache.spark spark-core ${project.version} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala new file mode 100644 index 0000000000..fec8737fcd --- /dev/null +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.storage.StorageLevel + +object Bagel extends Logging { + val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK + + /** + * Runs a Bagel program. + * @param sc [[org.apache.spark.SparkContext]] to use for the program. + * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be + * the vertex id. + * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an + * empty array, i.e. sc.parallelize(Array[K, Message]()). + * @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one + * message before sending (which often involves network I/O). + * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep, + * and provides the result to each vertex in the next superstep. + * @param partitioner [[org.apache.spark.Partitioner]] partitions values by key + * @param numPartitions number of partitions across which to split the graph. + * Default is the default parallelism of the SparkContext + * @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep. + * Defaults to caching in memory. + * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex, + * optional Aggregator and the current superstep, + * and returns a set of (Vertex, outgoing Messages) pairs + * @tparam K key + * @tparam V vertex type + * @tparam M message type + * @tparam C combiner + * @tparam A aggregator + * @return an RDD of (K, V) pairs representing the graph after completion of the program + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, + C: Manifest, A: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + aggregator: Option[Aggregator[V, A]], + partitioner: Partitioner, + numPartitions: Int, + storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL + )( + compute: (V, Option[C], Option[A], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism + + var superstep = 0 + var verts = vertices + var msgs = messages + var noActivity = false + do { + logInfo("Starting superstep "+superstep+".") + val startTime = System.currentTimeMillis + + val aggregated = agg(verts, aggregator) + val combinedMsgs = msgs.combineByKey( + combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) + val grouped = combinedMsgs.groupWith(verts) + val superstep_ = superstep // Create a read-only copy of superstep for capture in closure + val (processed, numMsgs, numActiveVerts) = + comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) + + val timeTaken = System.currentTimeMillis - startTime + logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) + + verts = processed.mapValues { case (vert, msgs) => vert } + msgs = processed.flatMap { + case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) + } + superstep += 1 + + noActivity = numMsgs == 0 && numActiveVerts == 0 + } while (!noActivity) + + verts + } + + /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default storage level */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + partitioner: Partitioner, + numPartitions: Int + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) + + /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + partitioner: Partitioner, + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, C](compute)) + } + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]] + * and default storage level + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + numPartitions: Int + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) + + /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default [[org.apache.spark.HashPartitioner]]*/ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numPartitions) + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, C](compute)) + } + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]], + * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + numPartitions: Int + )( + compute: (V, Option[Array[M]], Int) => (V, Array[M]) + ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], the default [[org.apache.spark.HashPartitioner]] + * and [[org.apache.spark.bagel.DefaultCombiner]] + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[Array[M]], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numPartitions) + run[K, V, M, Array[M], Nothing]( + sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, Array[M]](compute)) + } + + /** + * Aggregates the given vertices using the given aggregator, if it + * is specified. + */ + private def agg[K, V <: Vertex, A: Manifest]( + verts: RDD[(K, V)], + aggregator: Option[Aggregator[V, A]] + ): Option[A] = aggregator match { + case Some(a) => + Some(verts.map { + case (id, vert) => a.createAggregator(vert) + }.reduce(a.mergeAggregators(_, _))) + case None => None + } + + /** + * Processes the given vertex-message RDD using the compute + * function. Returns the processed RDD, the number of messages + * created, and the number of active vertices. + */ + private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( + sc: SparkContext, + grouped: RDD[(K, (Seq[C], Seq[V]))], + compute: (V, Option[C]) => (V, Array[M]), + storageLevel: StorageLevel + ): (RDD[(K, (V, Array[M]))], Int, Int) = { + var numMsgs = sc.accumulator(0) + var numActiveVerts = sc.accumulator(0) + val processed = grouped.flatMapValues { + case (_, vs) if vs.size == 0 => None + case (c, vs) => + val (newVert, newMsgs) = + compute(vs(0), c match { + case Seq(comb) => Some(comb) + case Seq() => None + }) + + numMsgs += newMsgs.size + if (newVert.active) + numActiveVerts += 1 + + Some((newVert, newMsgs)) + }.persist(storageLevel) + + // Force evaluation of processed RDD for accurate performance measurements + processed.foreach(x => {}) + + (processed, numMsgs.value, numActiveVerts.value) + } + + /** + * Converts a compute function that doesn't take an aggregator to + * one that does, so it can be passed to Bagel.run. + */ + private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( + compute: (V, Option[C], Int) => (V, Array[M]) + ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { + (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => + compute(vert, msgs, superstep) + } +} + +trait Combiner[M, C] { + def createCombiner(msg: M): C + def mergeMsg(combiner: C, msg: M): C + def mergeCombiners(a: C, b: C): C +} + +trait Aggregator[V, A] { + def createAggregator(vert: V): A + def mergeAggregators(a: A, b: A): A +} + +/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { + def createCombiner(msg: M): Array[M] = + Array(msg) + def mergeMsg(combiner: Array[M], msg: M): Array[M] = + combiner :+ msg + def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = + a ++ b +} + +/** + * Represents a Bagel vertex. + * + * Subclasses may store state along with each vertex and must + * inherit from java.io.Serializable or scala.Serializable. + */ +trait Vertex { + def active: Boolean +} + +/** + * Represents a Bagel message to a target vertex. + * + * Subclasses may contain a payload to deliver to the target vertex + * and must inherit from java.io.Serializable or scala.Serializable. + */ +trait Message[K] { + def targetId: K +} diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala deleted file mode 100644 index 80c8d53d2b..0000000000 --- a/bagel/src/main/scala/spark/bagel/Bagel.scala +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel - -import spark._ -import spark.SparkContext._ - -import scala.collection.mutable.ArrayBuffer -import storage.StorageLevel - -object Bagel extends Logging { - val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK - - /** - * Runs a Bagel program. - * @param sc [[spark.SparkContext]] to use for the program. - * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be - * the vertex id. - * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an - * empty array, i.e. sc.parallelize(Array[K, Message]()). - * @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one - * message before sending (which often involves network I/O). - * @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep, - * and provides the result to each vertex in the next superstep. - * @param partitioner [[spark.Partitioner]] partitions values by key - * @param numPartitions number of partitions across which to split the graph. - * Default is the default parallelism of the SparkContext - * @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep. - * Defaults to caching in memory. - * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex, - * optional Aggregator and the current superstep, - * and returns a set of (Vertex, outgoing Messages) pairs - * @tparam K key - * @tparam V vertex type - * @tparam M message type - * @tparam C combiner - * @tparam A aggregator - * @return an RDD of (K, V) pairs representing the graph after completion of the program - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C: Manifest, A: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - aggregator: Option[Aggregator[V, A]], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL - )( - compute: (V, Option[C], Option[A], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism - - var superstep = 0 - var verts = vertices - var msgs = messages - var noActivity = false - do { - logInfo("Starting superstep "+superstep+".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( - combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) - val grouped = combinedMsgs.groupWith(verts) - val superstep_ = superstep // Create a read-only copy of superstep for capture in closure - val (processed, numMsgs, numActiveVerts) = - comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - verts = processed.mapValues { case (vert, msgs) => vert } - msgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - superstep += 1 - - noActivity = numMsgs == 0 && numActiveVerts == 0 - } while (!noActivity) - - verts - } - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]] - * and default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]]*/ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]], - * [[spark.bagel.DefaultCombiner]] and the default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]] - * and [[spark.bagel.DefaultCombiner]] - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, Array[M]](compute)) - } - - /** - * Aggregates the given vertices using the given aggregator, if it - * is specified. - */ - private def agg[K, V <: Vertex, A: Manifest]( - verts: RDD[(K, V)], - aggregator: Option[Aggregator[V, A]] - ): Option[A] = aggregator match { - case Some(a) => - Some(verts.map { - case (id, vert) => a.createAggregator(vert) - }.reduce(a.mergeAggregators(_, _))) - case None => None - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. - */ - private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( - sc: SparkContext, - grouped: RDD[(K, (Seq[C], Seq[V]))], - compute: (V, Option[C]) => (V, Array[M]), - storageLevel: StorageLevel - ): (RDD[(K, (V, Array[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.flatMapValues { - case (_, vs) if vs.size == 0 => None - case (c, vs) => - val (newVert, newMsgs) = - compute(vs(0), c match { - case Seq(comb) => Some(comb) - case Seq() => None - }) - - numMsgs += newMsgs.size - if (newVert.active) - numActiveVerts += 1 - - Some((newVert, newMsgs)) - }.persist(storageLevel) - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Bagel.run. - */ - private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( - compute: (V, Option[C], Int) => (V, Array[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { - (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => - compute(vert, msgs, superstep) - } -} - -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ -class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { - def createCombiner(msg: M): Array[M] = - Array(msg) - def mergeMsg(combiner: Array[M], msg: M): Array[M] = - combiner :+ msg - def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = - a ++ b -} - -/** - * Represents a Bagel vertex. - * - * Subclasses may store state along with each vertex and must - * inherit from java.io.Serializable or scala.Serializable. - */ -trait Vertex { - def active: Boolean -} - -/** - * Represents a Bagel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must inherit from java.io.Serializable or scala.Serializable. - */ -trait Message[K] { - def targetId: K -} diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala deleted file mode 100644 index ef2d57fbd0..0000000000 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel - -import org.scalatest.{FunSuite, Assertions, BeforeAndAfter} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import scala.collection.mutable.ArrayBuffer - -import spark._ -import storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala new file mode 100644 index 0000000000..7b954a4775 --- /dev/null +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.bagel + +import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.storage.StorageLevel + +class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable +class TestMessage(val targetId: String) extends Message[String] with Serializable + +class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + test("halting by voting") { + sc = new SparkContext("local", "test") + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 5 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + + test("halting by message silence") { + sc = new SparkContext("local", "test") + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) + val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) + val numSupersteps = 5 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + val msgsOut = + msgs match { + case Some(ms) if (superstep < numSupersteps - 1) => + ms + case _ => + Array[TestMessage]() + } + (new TestVertex(self.active, self.age + 1), msgsOut) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + + test("large number of iterations") { + // This tests whether jobs with a large number of iterations finish in a reasonable time, + // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang + failAfter(10 seconds) { + sc = new SparkContext("local", "test") + val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 50 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + } + + test("using non-default persistence level") { + failAfter(10 seconds) { + sc = new SparkContext("local", "test") + val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 50 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + } +} diff --git a/bin/start-master.sh b/bin/start-master.sh index 2288fb19d7..648c7ae75f 100755 --- a/bin/start-master.sh +++ b/bin/start-master.sh @@ -49,4 +49,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT +"$bin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT diff --git a/bin/start-slave.sh b/bin/start-slave.sh index d6db16882d..4eefa20944 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" +"$bin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" diff --git a/bin/stop-master.sh b/bin/stop-master.sh index 31a610bf9d..310e33bedc 100755 --- a/bin/stop-master.sh +++ b/bin/stop-master.sh @@ -24,4 +24,4 @@ bin=`cd "$bin"; pwd` . "$bin/spark-config.sh" -"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1 +"$bin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 diff --git a/bin/stop-slaves.sh b/bin/stop-slaves.sh index 8e056f23d4..03e416a132 100755 --- a/bin/stop-slaves.sh +++ b/bin/stop-slaves.sh @@ -29,9 +29,9 @@ if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then fi if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1 + "$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 else for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 )) + "$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) done fi diff --git a/core/pom.xml b/core/pom.xml index 53696367e9..c803217f96 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-core jar Spark Project Core diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..20a7a3aa8c --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class FileClient { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min + + public FileClient(FileClientHandler handler, int connectTimeout) { + this.handler = handler; + this.connectTimeout = connectTimeout; + } + + public void init() { + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) + .handler(new FileClientChannelInitializer(handler)); + } + + public void connect(String host, int port) { + try { + // Start the connection attempt. + channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose() { + try { + channel.closeFuture().sync(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); + } + } + + public void sendRequest(String file) { + //assert(file == null); + //assert(channel == null); + channel.write(file + "\r\n"); + } + + public void close() { + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..65ee15d63b --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringEncoder; + + +class FileClientChannelInitializer extends ChannelInitializer { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..c4aa2669e0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; + + +abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()) { + handle(ctx, in, currentHeader); + handlerCalled = true; + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..666432474d --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + + private ServerBootstrap bootstrap = null; + private ChannelFuture channelFuture = null; + private int port = 0; + private Thread blockingThread = null; + + public FileServer(PathResolver pResolver, int port) { + InetSocketAddress addr = new InetSocketAddress(port); + + // Configure the server. + bootstrap = new ServerBootstrap(); + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channelFuture = bootstrap.bind(addr); + try { + // Get the address we bound to. + InetSocketAddress boundAddress = + ((InetSocketAddress) channelFuture.sync().channel().localAddress()); + this.port = boundAddress.getPort(); + } catch (InterruptedException ie) { + this.port = 0; + } + } + + /** + * Start the file server asynchronously in a new thread. + */ + public void start() { + blockingThread = new Thread() { + public void run() { + try { + channelFuture.channel().closeFuture().sync(); + LOG.info("FileServer exiting"); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); + } + // NOTE: bootstrap is shutdown in stop() + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); + } + + public int getPort() { + return port; + } + + public void stop() { + // Close the bound channel. + if (channelFuture != null) { + channelFuture.channel().close(); + channelFuture = null; + } + // Shutdown bootstrap. + if (bootstrap != null) { + bootstrap.shutdown(); + bootstrap = null; + } + // TODO: Shutdown all accepted channels as well ? + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..833af1632d --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.string.StringDecoder; + + +class FileServerChannelInitializer extends ChannelInitializer { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..d3d57a0255 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import java.io.File; +import java.io.FileInputStream; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; + + +class FileServerHandler extends ChannelInboundMessageHandlerAdapter { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + //logger.warning("Exception when sending file : " + file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..94c034cad0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + + +public interface PathResolver { + /** + * Get the absolute path of the file + * + * @param fileId + * @return the absolute path of file + */ + public String getAbsolutePath(String fileId); +} diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java deleted file mode 100644 index 0625a6d502..0000000000 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelOption; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private FileClientHandler handler = null; - private Channel channel = null; - private Bootstrap bootstrap = null; - private int connectTimeout = 60*1000; // 1 min - - public FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - bootstrap = new Bootstrap(); - bootstrap.group(new OioEventLoopGroup()) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. - channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - channel.write(file + "\r\n"); - } - - public void close() { - if(channel != null) { - channel.close(); - channel = null; - } - if ( bootstrap!=null) { - bootstrap.shutdown(); - bootstrap = null; - } - } -} diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java deleted file mode 100644 index 05ad4b61d7..0000000000 --- a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.buffer.BufType; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; - - -class FileClientChannelInitializer extends ChannelInitializer { - - private FileClientHandler fhandler; - - public FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } - - @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder(BufType.BYTE)) - .addLast("handler", fhandler); - } -} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java deleted file mode 100644 index e8cd9801f6..0000000000 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundByteHandlerAdapter; - - -abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { - - private FileHeader currentHeader = null; - - private volatile boolean handlerCalled = false; - - public boolean isComplete() { - return handlerCalled; - } - - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(String blockId); - - @Override - public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { - // Use direct buffer if possible. - return ctx.alloc().ioBuffer(); - } - - @Override - public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); - } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); - } - } - -} - diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java deleted file mode 100644 index 9f009a61d5..0000000000 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - - private ServerBootstrap bootstrap = null; - private ChannelFuture channelFuture = null; - private int port = 0; - private Thread blockingThread = null; - - public FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bootstrap = new ServerBootstrap(); - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. - InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - blockingThread = new Thread() { - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close(); - channelFuture = null; - } - // Shutdown bootstrap. - if (bootstrap != null) { - bootstrap.shutdown(); - bootstrap = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java deleted file mode 100644 index 50c57a81a3..0000000000 --- a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; - - -class FileServerChannelInitializer extends ChannelInitializer { - - PathResolver pResolver; - - public FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; - } - - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("strDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); - } -} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java deleted file mode 100644 index 176ba8da49..0000000000 --- a/core/src/main/java/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundMessageHandlerAdapter; -import io.netty.channel.DefaultFileRegion; - - -class FileServerHandler extends ChannelInboundMessageHandlerAdapter { - - PathResolver pResolver; - - public FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void messageReceived(ChannelHandlerContext ctx, String blockId) { - String path = pResolver.getAbsolutePath(blockId); - // if getFilePath returns null, close the channel - if (path == null) { - //ctx.close(); - return; - } - File file = new File(path); - if (file.exists()) { - if (!file.isFile()) { - //logger.info("Not a file : " + file.getAbsolutePath()); - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = file.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = new Long(length).intValue(); - //logger.info("Sending block "+blockId+" filelen = "+len); - //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), 0, file.length())); - } catch (Exception e) { - //logger.warning("Exception when sending file : " + file.getAbsolutePath()); - e.printStackTrace(); - } - } else { - //logger.warning("File not found: " + file.getAbsolutePath()); - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - cause.printStackTrace(); - ctx.close(); - } -} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java deleted file mode 100755 index f446c55b19..0000000000 --- a/core/src/main/java/spark/network/netty/PathResolver.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - - -public interface PathResolver { - /** - * Get the absolute path of the file - * - * @param fileId - * @return the absolute path of file - */ - public String getAbsolutePath(String fileId); -} diff --git a/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css b/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css new file mode 100755 index 0000000000..13cef3d6f1 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css @@ -0,0 +1,874 @@ +/*! + * Bootstrap v2.3.2 + * + * Copyright 2013 Twitter, Inc + * Licensed under the Apache License v2.0 + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Designed and built with all the love in the world @twitter by @mdo and @fat. + */ +.clearfix{*zoom:1;}.clearfix:before,.clearfix:after{display:table;content:"";line-height:0;} +.clearfix:after{clear:both;} +.hide-text{font:0/0 a;color:transparent;text-shadow:none;background-color:transparent;border:0;} +.input-block-level{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} +article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block;} +audio,canvas,video{display:inline-block;*display:inline;*zoom:1;} +audio:not([controls]){display:none;} +html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%;} +a:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +a:hover,a:active{outline:0;} +sub,sup{position:relative;font-size:75%;line-height:0;vertical-align:baseline;} +sup{top:-0.5em;} +sub{bottom:-0.25em;} +img{max-width:100%;width:auto\9;height:auto;vertical-align:middle;border:0;-ms-interpolation-mode:bicubic;} +#map_canvas img,.google-maps img{max-width:none;} +button,input,select,textarea{margin:0;font-size:100%;vertical-align:middle;} +button,input{*overflow:visible;line-height:normal;} +button::-moz-focus-inner,input::-moz-focus-inner{padding:0;border:0;} +button,html input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer;} +label,select,button,input[type="button"],input[type="reset"],input[type="submit"],input[type="radio"],input[type="checkbox"]{cursor:pointer;} +input[type="search"]{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box;-webkit-appearance:textfield;} +input[type="search"]::-webkit-search-decoration,input[type="search"]::-webkit-search-cancel-button{-webkit-appearance:none;} +textarea{overflow:auto;vertical-align:top;} +@media print{*{text-shadow:none !important;color:#000 !important;background:transparent !important;box-shadow:none !important;} a,a:visited{text-decoration:underline;} a[href]:after{content:" (" attr(href) ")";} abbr[title]:after{content:" (" attr(title) ")";} .ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:"";} pre,blockquote{border:1px solid #999;page-break-inside:avoid;} thead{display:table-header-group;} tr,img{page-break-inside:avoid;} img{max-width:100% !important;} @page {margin:0.5cm;}p,h2,h3{orphans:3;widows:3;} h2,h3{page-break-after:avoid;}}body{margin:0;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;line-height:20px;color:#333333;background-color:#ffffff;} +a{color:#0088cc;text-decoration:none;} +a:hover,a:focus{color:#005580;text-decoration:underline;} +.img-rounded{-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.img-polaroid{padding:4px;background-color:#fff;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);-webkit-box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);-moz-box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);} +.img-circle{-webkit-border-radius:500px;-moz-border-radius:500px;border-radius:500px;} +.row{margin-left:-20px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} +.row:after{clear:both;} +[class*="span"]{float:left;min-height:1px;margin-left:20px;} +.container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:940px;} +.span12{width:940px;} +.span11{width:860px;} +.span10{width:780px;} +.span9{width:700px;} +.span8{width:620px;} +.span7{width:540px;} +.span6{width:460px;} +.span5{width:380px;} +.span4{width:300px;} +.span3{width:220px;} +.span2{width:140px;} +.span1{width:60px;} +.offset12{margin-left:980px;} +.offset11{margin-left:900px;} +.offset10{margin-left:820px;} +.offset9{margin-left:740px;} +.offset8{margin-left:660px;} +.offset7{margin-left:580px;} +.offset6{margin-left:500px;} +.offset5{margin-left:420px;} +.offset4{margin-left:340px;} +.offset3{margin-left:260px;} +.offset2{margin-left:180px;} +.offset1{margin-left:100px;} +.row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} +.row-fluid:after{clear:both;} +.row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.127659574468085%;*margin-left:2.074468085106383%;} +.row-fluid [class*="span"]:first-child{margin-left:0;} +.row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.127659574468085%;} +.row-fluid .span12{width:100%;*width:99.94680851063829%;} +.row-fluid .span11{width:91.48936170212765%;*width:91.43617021276594%;} +.row-fluid .span10{width:82.97872340425532%;*width:82.92553191489361%;} +.row-fluid .span9{width:74.46808510638297%;*width:74.41489361702126%;} +.row-fluid .span8{width:65.95744680851064%;*width:65.90425531914893%;} +.row-fluid .span7{width:57.44680851063829%;*width:57.39361702127659%;} +.row-fluid .span6{width:48.93617021276595%;*width:48.88297872340425%;} +.row-fluid .span5{width:40.42553191489362%;*width:40.37234042553192%;} +.row-fluid .span4{width:31.914893617021278%;*width:31.861702127659576%;} +.row-fluid .span3{width:23.404255319148934%;*width:23.351063829787233%;} +.row-fluid .span2{width:14.893617021276595%;*width:14.840425531914894%;} +.row-fluid .span1{width:6.382978723404255%;*width:6.329787234042553%;} +.row-fluid .offset12{margin-left:104.25531914893617%;*margin-left:104.14893617021275%;} +.row-fluid .offset12:first-child{margin-left:102.12765957446808%;*margin-left:102.02127659574467%;} +.row-fluid .offset11{margin-left:95.74468085106382%;*margin-left:95.6382978723404%;} +.row-fluid .offset11:first-child{margin-left:93.61702127659574%;*margin-left:93.51063829787232%;} +.row-fluid .offset10{margin-left:87.23404255319149%;*margin-left:87.12765957446807%;} +.row-fluid .offset10:first-child{margin-left:85.1063829787234%;*margin-left:84.99999999999999%;} +.row-fluid .offset9{margin-left:78.72340425531914%;*margin-left:78.61702127659572%;} +.row-fluid .offset9:first-child{margin-left:76.59574468085106%;*margin-left:76.48936170212764%;} +.row-fluid .offset8{margin-left:70.2127659574468%;*margin-left:70.10638297872339%;} +.row-fluid .offset8:first-child{margin-left:68.08510638297872%;*margin-left:67.9787234042553%;} +.row-fluid .offset7{margin-left:61.70212765957446%;*margin-left:61.59574468085106%;} +.row-fluid .offset7:first-child{margin-left:59.574468085106375%;*margin-left:59.46808510638297%;} +.row-fluid .offset6{margin-left:53.191489361702125%;*margin-left:53.085106382978715%;} +.row-fluid .offset6:first-child{margin-left:51.063829787234035%;*margin-left:50.95744680851063%;} +.row-fluid .offset5{margin-left:44.68085106382979%;*margin-left:44.57446808510638%;} +.row-fluid .offset5:first-child{margin-left:42.5531914893617%;*margin-left:42.4468085106383%;} +.row-fluid .offset4{margin-left:36.170212765957444%;*margin-left:36.06382978723405%;} +.row-fluid .offset4:first-child{margin-left:34.04255319148936%;*margin-left:33.93617021276596%;} +.row-fluid .offset3{margin-left:27.659574468085104%;*margin-left:27.5531914893617%;} +.row-fluid .offset3:first-child{margin-left:25.53191489361702%;*margin-left:25.425531914893618%;} +.row-fluid .offset2{margin-left:19.148936170212764%;*margin-left:19.04255319148936%;} +.row-fluid .offset2:first-child{margin-left:17.02127659574468%;*margin-left:16.914893617021278%;} +.row-fluid .offset1{margin-left:10.638297872340425%;*margin-left:10.53191489361702%;} +.row-fluid .offset1:first-child{margin-left:8.51063829787234%;*margin-left:8.404255319148938%;} +[class*="span"].hide,.row-fluid [class*="span"].hide{display:none;} +[class*="span"].pull-right,.row-fluid [class*="span"].pull-right{float:right;} +.container{margin-right:auto;margin-left:auto;*zoom:1;}.container:before,.container:after{display:table;content:"";line-height:0;} +.container:after{clear:both;} +.container-fluid{padding-right:20px;padding-left:20px;*zoom:1;}.container-fluid:before,.container-fluid:after{display:table;content:"";line-height:0;} +.container-fluid:after{clear:both;} +p{margin:0 0 10px;} +.lead{margin-bottom:20px;font-size:21px;font-weight:200;line-height:30px;} +small{font-size:85%;} +strong{font-weight:bold;} +em{font-style:italic;} +cite{font-style:normal;} +.muted{color:#999999;} +a.muted:hover,a.muted:focus{color:#808080;} +.text-warning{color:#c09853;} +a.text-warning:hover,a.text-warning:focus{color:#a47e3c;} +.text-error{color:#b94a48;} +a.text-error:hover,a.text-error:focus{color:#953b39;} +.text-info{color:#3a87ad;} +a.text-info:hover,a.text-info:focus{color:#2d6987;} +.text-success{color:#468847;} +a.text-success:hover,a.text-success:focus{color:#356635;} +.text-left{text-align:left;} +.text-right{text-align:right;} +.text-center{text-align:center;} +h1,h2,h3,h4,h5,h6{margin:10px 0;font-family:inherit;font-weight:bold;line-height:20px;color:inherit;text-rendering:optimizelegibility;}h1 small,h2 small,h3 small,h4 small,h5 small,h6 small{font-weight:normal;line-height:1;color:#999999;} +h1,h2,h3{line-height:40px;} +h1{font-size:38.5px;} +h2{font-size:31.5px;} +h3{font-size:24.5px;} +h4{font-size:17.5px;} +h5{font-size:14px;} +h6{font-size:11.9px;} +h1 small{font-size:24.5px;} +h2 small{font-size:17.5px;} +h3 small{font-size:14px;} +h4 small{font-size:14px;} +.page-header{padding-bottom:9px;margin:20px 0 30px;border-bottom:1px solid #eeeeee;} +ul,ol{padding:0;margin:0 0 10px 25px;} +ul ul,ul ol,ol ol,ol ul{margin-bottom:0;} +li{line-height:20px;} +ul.unstyled,ol.unstyled{margin-left:0;list-style:none;} +ul.inline,ol.inline{margin-left:0;list-style:none;}ul.inline>li,ol.inline>li{display:inline-block;*display:inline;*zoom:1;padding-left:5px;padding-right:5px;} +dl{margin-bottom:20px;} +dt,dd{line-height:20px;} +dt{font-weight:bold;} +dd{margin-left:10px;} +.dl-horizontal{*zoom:1;}.dl-horizontal:before,.dl-horizontal:after{display:table;content:"";line-height:0;} +.dl-horizontal:after{clear:both;} +.dl-horizontal dt{float:left;width:160px;clear:left;text-align:right;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;} +.dl-horizontal dd{margin-left:180px;} +hr{margin:20px 0;border:0;border-top:1px solid #eeeeee;border-bottom:1px solid #ffffff;} +abbr[title],abbr[data-original-title]{cursor:help;border-bottom:1px dotted #999999;} +abbr.initialism{font-size:90%;text-transform:uppercase;} +blockquote{padding:0 0 0 15px;margin:0 0 20px;border-left:5px solid #eeeeee;}blockquote p{margin-bottom:0;font-size:17.5px;font-weight:300;line-height:1.25;} +blockquote small{display:block;line-height:20px;color:#999999;}blockquote small:before{content:'\2014 \00A0';} +blockquote.pull-right{float:right;padding-right:15px;padding-left:0;border-right:5px solid #eeeeee;border-left:0;}blockquote.pull-right p,blockquote.pull-right small{text-align:right;} +blockquote.pull-right small:before{content:'';} +blockquote.pull-right small:after{content:'\00A0 \2014';} +q:before,q:after,blockquote:before,blockquote:after{content:"";} +address{display:block;margin-bottom:20px;font-style:normal;line-height:20px;} +code,pre{padding:0 3px 2px;font-family:Monaco,Menlo,Consolas,"Courier New",monospace;font-size:12px;color:#333333;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +code{padding:2px 4px;color:#d14;background-color:#f7f7f9;border:1px solid #e1e1e8;white-space:nowrap;} +pre{display:block;padding:9.5px;margin:0 0 10px;font-size:13px;line-height:20px;word-break:break-all;word-wrap:break-word;white-space:pre;white-space:pre-wrap;background-color:#f5f5f5;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.15);-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}pre.prettyprint{margin-bottom:20px;} +pre code{padding:0;color:inherit;white-space:pre;white-space:pre-wrap;background-color:transparent;border:0;} +.pre-scrollable{max-height:340px;overflow-y:scroll;} +.label,.badge{display:inline-block;padding:2px 4px;font-size:11.844px;font-weight:bold;line-height:14px;color:#ffffff;vertical-align:baseline;white-space:nowrap;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#999999;} +.label{-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.badge{padding-left:9px;padding-right:9px;-webkit-border-radius:9px;-moz-border-radius:9px;border-radius:9px;} +.label:empty,.badge:empty{display:none;} +a.label:hover,a.label:focus,a.badge:hover,a.badge:focus{color:#ffffff;text-decoration:none;cursor:pointer;} +.label-important,.badge-important{background-color:#b94a48;} +.label-important[href],.badge-important[href]{background-color:#953b39;} +.label-warning,.badge-warning{background-color:#f89406;} +.label-warning[href],.badge-warning[href]{background-color:#c67605;} +.label-success,.badge-success{background-color:#468847;} +.label-success[href],.badge-success[href]{background-color:#356635;} +.label-info,.badge-info{background-color:#3a87ad;} +.label-info[href],.badge-info[href]{background-color:#2d6987;} +.label-inverse,.badge-inverse{background-color:#333333;} +.label-inverse[href],.badge-inverse[href]{background-color:#1a1a1a;} +.btn .label,.btn .badge{position:relative;top:-1px;} +.btn-mini .label,.btn-mini .badge{top:0;} +table{max-width:100%;background-color:transparent;border-collapse:collapse;border-spacing:0;} +.table{width:100%;margin-bottom:20px;}.table th,.table td{padding:8px;line-height:20px;text-align:left;vertical-align:top;border-top:1px solid #dddddd;} +.table th{font-weight:bold;} +.table thead th{vertical-align:bottom;} +.table caption+thead tr:first-child th,.table caption+thead tr:first-child td,.table colgroup+thead tr:first-child th,.table colgroup+thead tr:first-child td,.table thead:first-child tr:first-child th,.table thead:first-child tr:first-child td{border-top:0;} +.table tbody+tbody{border-top:2px solid #dddddd;} +.table .table{background-color:#ffffff;} +.table-condensed th,.table-condensed td{padding:4px 5px;} +.table-bordered{border:1px solid #dddddd;border-collapse:separate;*border-collapse:collapse;border-left:0;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}.table-bordered th,.table-bordered td{border-left:1px solid #dddddd;} +.table-bordered caption+thead tr:first-child th,.table-bordered caption+tbody tr:first-child th,.table-bordered caption+tbody tr:first-child td,.table-bordered colgroup+thead tr:first-child th,.table-bordered colgroup+tbody tr:first-child th,.table-bordered colgroup+tbody tr:first-child td,.table-bordered thead:first-child tr:first-child th,.table-bordered tbody:first-child tr:first-child th,.table-bordered tbody:first-child tr:first-child td{border-top:0;} +.table-bordered thead:first-child tr:first-child>th:first-child,.table-bordered tbody:first-child tr:first-child>td:first-child,.table-bordered tbody:first-child tr:first-child>th:first-child{-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.table-bordered thead:first-child tr:first-child>th:last-child,.table-bordered tbody:first-child tr:first-child>td:last-child,.table-bordered tbody:first-child tr:first-child>th:last-child{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;} +.table-bordered thead:last-child tr:last-child>th:first-child,.table-bordered tbody:last-child tr:last-child>td:first-child,.table-bordered tbody:last-child tr:last-child>th:first-child,.table-bordered tfoot:last-child tr:last-child>td:first-child,.table-bordered tfoot:last-child tr:last-child>th:first-child{-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.table-bordered thead:last-child tr:last-child>th:last-child,.table-bordered tbody:last-child tr:last-child>td:last-child,.table-bordered tbody:last-child tr:last-child>th:last-child,.table-bordered tfoot:last-child tr:last-child>td:last-child,.table-bordered tfoot:last-child tr:last-child>th:last-child{-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.table-bordered tfoot+tbody:last-child tr:last-child td:first-child{-webkit-border-bottom-left-radius:0;-moz-border-radius-bottomleft:0;border-bottom-left-radius:0;} +.table-bordered tfoot+tbody:last-child tr:last-child td:last-child{-webkit-border-bottom-right-radius:0;-moz-border-radius-bottomright:0;border-bottom-right-radius:0;} +.table-bordered caption+thead tr:first-child th:first-child,.table-bordered caption+tbody tr:first-child td:first-child,.table-bordered colgroup+thead tr:first-child th:first-child,.table-bordered colgroup+tbody tr:first-child td:first-child{-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.table-bordered caption+thead tr:first-child th:last-child,.table-bordered caption+tbody tr:first-child td:last-child,.table-bordered colgroup+thead tr:first-child th:last-child,.table-bordered colgroup+tbody tr:first-child td:last-child{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;} +.table-striped tbody>tr:nth-child(odd)>td,.table-striped tbody>tr:nth-child(odd)>th{background-color:#f9f9f9;} +.table-hover tbody tr:hover>td,.table-hover tbody tr:hover>th{background-color:#f5f5f5;} +table td[class*="span"],table th[class*="span"],.row-fluid table td[class*="span"],.row-fluid table th[class*="span"]{display:table-cell;float:none;margin-left:0;} +.table td.span1,.table th.span1{float:none;width:44px;margin-left:0;} +.table td.span2,.table th.span2{float:none;width:124px;margin-left:0;} +.table td.span3,.table th.span3{float:none;width:204px;margin-left:0;} +.table td.span4,.table th.span4{float:none;width:284px;margin-left:0;} +.table td.span5,.table th.span5{float:none;width:364px;margin-left:0;} +.table td.span6,.table th.span6{float:none;width:444px;margin-left:0;} +.table td.span7,.table th.span7{float:none;width:524px;margin-left:0;} +.table td.span8,.table th.span8{float:none;width:604px;margin-left:0;} +.table td.span9,.table th.span9{float:none;width:684px;margin-left:0;} +.table td.span10,.table th.span10{float:none;width:764px;margin-left:0;} +.table td.span11,.table th.span11{float:none;width:844px;margin-left:0;} +.table td.span12,.table th.span12{float:none;width:924px;margin-left:0;} +.table tbody tr.success>td{background-color:#dff0d8;} +.table tbody tr.error>td{background-color:#f2dede;} +.table tbody tr.warning>td{background-color:#fcf8e3;} +.table tbody tr.info>td{background-color:#d9edf7;} +.table-hover tbody tr.success:hover>td{background-color:#d0e9c6;} +.table-hover tbody tr.error:hover>td{background-color:#ebcccc;} +.table-hover tbody tr.warning:hover>td{background-color:#faf2cc;} +.table-hover tbody tr.info:hover>td{background-color:#c4e3f3;} +form{margin:0 0 20px;} +fieldset{padding:0;margin:0;border:0;} +legend{display:block;width:100%;padding:0;margin-bottom:20px;font-size:21px;line-height:40px;color:#333333;border:0;border-bottom:1px solid #e5e5e5;}legend small{font-size:15px;color:#999999;} +label,input,button,select,textarea{font-size:14px;font-weight:normal;line-height:20px;} +input,button,select,textarea{font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;} +label{display:block;margin-bottom:5px;} +select,textarea,input[type="text"],input[type="password"],input[type="datetime"],input[type="datetime-local"],input[type="date"],input[type="month"],input[type="time"],input[type="week"],input[type="number"],input[type="email"],input[type="url"],input[type="search"],input[type="tel"],input[type="color"],.uneditable-input{display:inline-block;height:20px;padding:4px 6px;margin-bottom:10px;font-size:14px;line-height:20px;color:#555555;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;vertical-align:middle;} +input,textarea,.uneditable-input{width:206px;} +textarea{height:auto;} +textarea,input[type="text"],input[type="password"],input[type="datetime"],input[type="datetime-local"],input[type="date"],input[type="month"],input[type="time"],input[type="week"],input[type="number"],input[type="email"],input[type="url"],input[type="search"],input[type="tel"],input[type="color"],.uneditable-input{background-color:#ffffff;border:1px solid #cccccc;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-webkit-transition:border linear .2s, box-shadow linear .2s;-moz-transition:border linear .2s, box-shadow linear .2s;-o-transition:border linear .2s, box-shadow linear .2s;transition:border linear .2s, box-shadow linear .2s;}textarea:focus,input[type="text"]:focus,input[type="password"]:focus,input[type="datetime"]:focus,input[type="datetime-local"]:focus,input[type="date"]:focus,input[type="month"]:focus,input[type="time"]:focus,input[type="week"]:focus,input[type="number"]:focus,input[type="email"]:focus,input[type="url"]:focus,input[type="search"]:focus,input[type="tel"]:focus,input[type="color"]:focus,.uneditable-input:focus{border-color:rgba(82, 168, 236, 0.8);outline:0;outline:thin dotted \9;-webkit-box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);-moz-box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);} +input[type="radio"],input[type="checkbox"]{margin:4px 0 0;*margin-top:0;margin-top:1px \9;line-height:normal;} +input[type="file"],input[type="image"],input[type="submit"],input[type="reset"],input[type="button"],input[type="radio"],input[type="checkbox"]{width:auto;} +select,input[type="file"]{height:30px;*margin-top:4px;line-height:30px;} +select{width:220px;border:1px solid #cccccc;background-color:#ffffff;} +select[multiple],select[size]{height:auto;} +select:focus,input[type="file"]:focus,input[type="radio"]:focus,input[type="checkbox"]:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +.uneditable-input,.uneditable-textarea{color:#999999;background-color:#fcfcfc;border-color:#cccccc;-webkit-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);-moz-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);cursor:not-allowed;} +.uneditable-input{overflow:hidden;white-space:nowrap;} +.uneditable-textarea{width:auto;height:auto;} +input:-moz-placeholder,textarea:-moz-placeholder{color:#999999;} +input:-ms-input-placeholder,textarea:-ms-input-placeholder{color:#999999;} +input::-webkit-input-placeholder,textarea::-webkit-input-placeholder{color:#999999;} +.radio,.checkbox{min-height:20px;padding-left:20px;} +.radio input[type="radio"],.checkbox input[type="checkbox"]{float:left;margin-left:-20px;} +.controls>.radio:first-child,.controls>.checkbox:first-child{padding-top:5px;} +.radio.inline,.checkbox.inline{display:inline-block;padding-top:5px;margin-bottom:0;vertical-align:middle;} +.radio.inline+.radio.inline,.checkbox.inline+.checkbox.inline{margin-left:10px;} +.input-mini{width:60px;} +.input-small{width:90px;} +.input-medium{width:150px;} +.input-large{width:210px;} +.input-xlarge{width:270px;} +.input-xxlarge{width:530px;} +input[class*="span"],select[class*="span"],textarea[class*="span"],.uneditable-input[class*="span"],.row-fluid input[class*="span"],.row-fluid select[class*="span"],.row-fluid textarea[class*="span"],.row-fluid .uneditable-input[class*="span"]{float:none;margin-left:0;} +.input-append input[class*="span"],.input-append .uneditable-input[class*="span"],.input-prepend input[class*="span"],.input-prepend .uneditable-input[class*="span"],.row-fluid input[class*="span"],.row-fluid select[class*="span"],.row-fluid textarea[class*="span"],.row-fluid .uneditable-input[class*="span"],.row-fluid .input-prepend [class*="span"],.row-fluid .input-append [class*="span"]{display:inline-block;} +input,textarea,.uneditable-input{margin-left:0;} +.controls-row [class*="span"]+[class*="span"]{margin-left:20px;} +input.span12,textarea.span12,.uneditable-input.span12{width:926px;} +input.span11,textarea.span11,.uneditable-input.span11{width:846px;} +input.span10,textarea.span10,.uneditable-input.span10{width:766px;} +input.span9,textarea.span9,.uneditable-input.span9{width:686px;} +input.span8,textarea.span8,.uneditable-input.span8{width:606px;} +input.span7,textarea.span7,.uneditable-input.span7{width:526px;} +input.span6,textarea.span6,.uneditable-input.span6{width:446px;} +input.span5,textarea.span5,.uneditable-input.span5{width:366px;} +input.span4,textarea.span4,.uneditable-input.span4{width:286px;} +input.span3,textarea.span3,.uneditable-input.span3{width:206px;} +input.span2,textarea.span2,.uneditable-input.span2{width:126px;} +input.span1,textarea.span1,.uneditable-input.span1{width:46px;} +.controls-row{*zoom:1;}.controls-row:before,.controls-row:after{display:table;content:"";line-height:0;} +.controls-row:after{clear:both;} +.controls-row [class*="span"],.row-fluid .controls-row [class*="span"]{float:left;} +.controls-row .checkbox[class*="span"],.controls-row .radio[class*="span"]{padding-top:5px;} +input[disabled],select[disabled],textarea[disabled],input[readonly],select[readonly],textarea[readonly]{cursor:not-allowed;background-color:#eeeeee;} +input[type="radio"][disabled],input[type="checkbox"][disabled],input[type="radio"][readonly],input[type="checkbox"][readonly]{background-color:transparent;} +.control-group.warning .control-label,.control-group.warning .help-block,.control-group.warning .help-inline{color:#c09853;} +.control-group.warning .checkbox,.control-group.warning .radio,.control-group.warning input,.control-group.warning select,.control-group.warning textarea{color:#c09853;} +.control-group.warning input,.control-group.warning select,.control-group.warning textarea{border-color:#c09853;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.warning input:focus,.control-group.warning select:focus,.control-group.warning textarea:focus{border-color:#a47e3c;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;} +.control-group.warning .input-prepend .add-on,.control-group.warning .input-append .add-on{color:#c09853;background-color:#fcf8e3;border-color:#c09853;} +.control-group.error .control-label,.control-group.error .help-block,.control-group.error .help-inline{color:#b94a48;} +.control-group.error .checkbox,.control-group.error .radio,.control-group.error input,.control-group.error select,.control-group.error textarea{color:#b94a48;} +.control-group.error input,.control-group.error select,.control-group.error textarea{border-color:#b94a48;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.error input:focus,.control-group.error select:focus,.control-group.error textarea:focus{border-color:#953b39;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;} +.control-group.error .input-prepend .add-on,.control-group.error .input-append .add-on{color:#b94a48;background-color:#f2dede;border-color:#b94a48;} +.control-group.success .control-label,.control-group.success .help-block,.control-group.success .help-inline{color:#468847;} +.control-group.success .checkbox,.control-group.success .radio,.control-group.success input,.control-group.success select,.control-group.success textarea{color:#468847;} +.control-group.success input,.control-group.success select,.control-group.success textarea{border-color:#468847;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.success input:focus,.control-group.success select:focus,.control-group.success textarea:focus{border-color:#356635;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;} +.control-group.success .input-prepend .add-on,.control-group.success .input-append .add-on{color:#468847;background-color:#dff0d8;border-color:#468847;} +.control-group.info .control-label,.control-group.info .help-block,.control-group.info .help-inline{color:#3a87ad;} +.control-group.info .checkbox,.control-group.info .radio,.control-group.info input,.control-group.info select,.control-group.info textarea{color:#3a87ad;} +.control-group.info input,.control-group.info select,.control-group.info textarea{border-color:#3a87ad;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.info input:focus,.control-group.info select:focus,.control-group.info textarea:focus{border-color:#2d6987;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;} +.control-group.info .input-prepend .add-on,.control-group.info .input-append .add-on{color:#3a87ad;background-color:#d9edf7;border-color:#3a87ad;} +input:focus:invalid,textarea:focus:invalid,select:focus:invalid{color:#b94a48;border-color:#ee5f5b;}input:focus:invalid:focus,textarea:focus:invalid:focus,select:focus:invalid:focus{border-color:#e9322d;-webkit-box-shadow:0 0 6px #f8b9b7;-moz-box-shadow:0 0 6px #f8b9b7;box-shadow:0 0 6px #f8b9b7;} +.form-actions{padding:19px 20px 20px;margin-top:20px;margin-bottom:20px;background-color:#f5f5f5;border-top:1px solid #e5e5e5;*zoom:1;}.form-actions:before,.form-actions:after{display:table;content:"";line-height:0;} +.form-actions:after{clear:both;} +.help-block,.help-inline{color:#595959;} +.help-block{display:block;margin-bottom:10px;} +.help-inline{display:inline-block;*display:inline;*zoom:1;vertical-align:middle;padding-left:5px;} +.input-append,.input-prepend{display:inline-block;margin-bottom:10px;vertical-align:middle;font-size:0;white-space:nowrap;}.input-append input,.input-prepend input,.input-append select,.input-prepend select,.input-append .uneditable-input,.input-prepend .uneditable-input,.input-append .dropdown-menu,.input-prepend .dropdown-menu,.input-append .popover,.input-prepend .popover{font-size:14px;} +.input-append input,.input-prepend input,.input-append select,.input-prepend select,.input-append .uneditable-input,.input-prepend .uneditable-input{position:relative;margin-bottom:0;*margin-left:0;vertical-align:top;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;}.input-append input:focus,.input-prepend input:focus,.input-append select:focus,.input-prepend select:focus,.input-append .uneditable-input:focus,.input-prepend .uneditable-input:focus{z-index:2;} +.input-append .add-on,.input-prepend .add-on{display:inline-block;width:auto;height:20px;min-width:16px;padding:4px 5px;font-size:14px;font-weight:normal;line-height:20px;text-align:center;text-shadow:0 1px 0 #ffffff;background-color:#eeeeee;border:1px solid #ccc;} +.input-append .add-on,.input-prepend .add-on,.input-append .btn,.input-prepend .btn,.input-append .btn-group>.dropdown-toggle,.input-prepend .btn-group>.dropdown-toggle{vertical-align:top;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.input-append .active,.input-prepend .active{background-color:#a9dba9;border-color:#46a546;} +.input-prepend .add-on,.input-prepend .btn{margin-right:-1px;} +.input-prepend .add-on:first-child,.input-prepend .btn:first-child{-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.input-append input,.input-append select,.input-append .uneditable-input{-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;}.input-append input+.btn-group .btn:last-child,.input-append select+.btn-group .btn:last-child,.input-append .uneditable-input+.btn-group .btn:last-child{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-append .add-on,.input-append .btn,.input-append .btn-group{margin-left:-1px;} +.input-append .add-on:last-child,.input-append .btn:last-child,.input-append .btn-group:last-child>.dropdown-toggle{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append input,.input-prepend.input-append select,.input-prepend.input-append .uneditable-input{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;}.input-prepend.input-append input+.btn-group .btn,.input-prepend.input-append select+.btn-group .btn,.input-prepend.input-append .uneditable-input+.btn-group .btn{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append .add-on:first-child,.input-prepend.input-append .btn:first-child{margin-right:-1px;-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.input-prepend.input-append .add-on:last-child,.input-prepend.input-append .btn:last-child{margin-left:-1px;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append .btn-group:first-child{margin-left:0;} +input.search-query{padding-right:14px;padding-right:4px \9;padding-left:14px;padding-left:4px \9;margin-bottom:0;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.form-search .input-append .search-query,.form-search .input-prepend .search-query{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.form-search .input-append .search-query{-webkit-border-radius:14px 0 0 14px;-moz-border-radius:14px 0 0 14px;border-radius:14px 0 0 14px;} +.form-search .input-append .btn{-webkit-border-radius:0 14px 14px 0;-moz-border-radius:0 14px 14px 0;border-radius:0 14px 14px 0;} +.form-search .input-prepend .search-query{-webkit-border-radius:0 14px 14px 0;-moz-border-radius:0 14px 14px 0;border-radius:0 14px 14px 0;} +.form-search .input-prepend .btn{-webkit-border-radius:14px 0 0 14px;-moz-border-radius:14px 0 0 14px;border-radius:14px 0 0 14px;} +.form-search input,.form-inline input,.form-horizontal input,.form-search textarea,.form-inline textarea,.form-horizontal textarea,.form-search select,.form-inline select,.form-horizontal select,.form-search .help-inline,.form-inline .help-inline,.form-horizontal .help-inline,.form-search .uneditable-input,.form-inline .uneditable-input,.form-horizontal .uneditable-input,.form-search .input-prepend,.form-inline .input-prepend,.form-horizontal .input-prepend,.form-search .input-append,.form-inline .input-append,.form-horizontal .input-append{display:inline-block;*display:inline;*zoom:1;margin-bottom:0;vertical-align:middle;} +.form-search .hide,.form-inline .hide,.form-horizontal .hide{display:none;} +.form-search label,.form-inline label,.form-search .btn-group,.form-inline .btn-group{display:inline-block;} +.form-search .input-append,.form-inline .input-append,.form-search .input-prepend,.form-inline .input-prepend{margin-bottom:0;} +.form-search .radio,.form-search .checkbox,.form-inline .radio,.form-inline .checkbox{padding-left:0;margin-bottom:0;vertical-align:middle;} +.form-search .radio input[type="radio"],.form-search .checkbox input[type="checkbox"],.form-inline .radio input[type="radio"],.form-inline .checkbox input[type="checkbox"]{float:left;margin-right:3px;margin-left:0;} +.control-group{margin-bottom:10px;} +legend+.control-group{margin-top:20px;-webkit-margin-top-collapse:separate;} +.form-horizontal .control-group{margin-bottom:20px;*zoom:1;}.form-horizontal .control-group:before,.form-horizontal .control-group:after{display:table;content:"";line-height:0;} +.form-horizontal .control-group:after{clear:both;} +.form-horizontal .control-label{float:left;width:160px;padding-top:5px;text-align:right;} +.form-horizontal .controls{*display:inline-block;*padding-left:20px;margin-left:180px;*margin-left:0;}.form-horizontal .controls:first-child{*padding-left:180px;} +.form-horizontal .help-block{margin-bottom:0;} +.form-horizontal input+.help-block,.form-horizontal select+.help-block,.form-horizontal textarea+.help-block,.form-horizontal .uneditable-input+.help-block,.form-horizontal .input-prepend+.help-block,.form-horizontal .input-append+.help-block{margin-top:10px;} +.form-horizontal .form-actions{padding-left:180px;} +.btn{display:inline-block;*display:inline;*zoom:1;padding:4px 12px;margin-bottom:0;font-size:14px;line-height:20px;text-align:center;vertical-align:middle;cursor:pointer;color:#333333;text-shadow:0 1px 1px rgba(255, 255, 255, 0.75);background-color:#f5f5f5;background-image:-moz-linear-gradient(top, #ffffff, #e6e6e6);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ffffff), to(#e6e6e6));background-image:-webkit-linear-gradient(top, #ffffff, #e6e6e6);background-image:-o-linear-gradient(top, #ffffff, #e6e6e6);background-image:linear-gradient(to bottom, #ffffff, #e6e6e6);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffffff', endColorstr='#ffe6e6e6', GradientType=0);border-color:#e6e6e6 #e6e6e6 #bfbfbf;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#e6e6e6;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);border:1px solid #cccccc;*border:0;border-bottom-color:#b3b3b3;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;*margin-left:.3em;-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);}.btn:hover,.btn:focus,.btn:active,.btn.active,.btn.disabled,.btn[disabled]{color:#333333;background-color:#e6e6e6;*background-color:#d9d9d9;} +.btn:active,.btn.active{background-color:#cccccc \9;} +.btn:first-child{*margin-left:0;} +.btn:hover,.btn:focus{color:#333333;text-decoration:none;background-position:0 -15px;-webkit-transition:background-position 0.1s linear;-moz-transition:background-position 0.1s linear;-o-transition:background-position 0.1s linear;transition:background-position 0.1s linear;} +.btn:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +.btn.active,.btn:active{background-image:none;outline:0;-webkit-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);} +.btn.disabled,.btn[disabled]{cursor:default;background-image:none;opacity:0.65;filter:alpha(opacity=65);-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} +.btn-large{padding:11px 19px;font-size:17.5px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.btn-large [class^="icon-"],.btn-large [class*=" icon-"]{margin-top:4px;} +.btn-small{padding:2px 10px;font-size:11.9px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.btn-small [class^="icon-"],.btn-small [class*=" icon-"]{margin-top:0;} +.btn-mini [class^="icon-"],.btn-mini [class*=" icon-"]{margin-top:-1px;} +.btn-mini{padding:0 6px;font-size:10.5px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.btn-block{display:block;width:100%;padding-left:0;padding-right:0;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} +.btn-block+.btn-block{margin-top:5px;} +input[type="submit"].btn-block,input[type="reset"].btn-block,input[type="button"].btn-block{width:100%;} +.btn-primary.active,.btn-warning.active,.btn-danger.active,.btn-success.active,.btn-info.active,.btn-inverse.active{color:rgba(255, 255, 255, 0.75);} +.btn-primary{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#006dcc;background-image:-moz-linear-gradient(top, #0088cc, #0044cc);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0044cc));background-image:-webkit-linear-gradient(top, #0088cc, #0044cc);background-image:-o-linear-gradient(top, #0088cc, #0044cc);background-image:linear-gradient(to bottom, #0088cc, #0044cc);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0044cc', GradientType=0);border-color:#0044cc #0044cc #002a80;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#0044cc;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-primary:hover,.btn-primary:focus,.btn-primary:active,.btn-primary.active,.btn-primary.disabled,.btn-primary[disabled]{color:#ffffff;background-color:#0044cc;*background-color:#003bb3;} +.btn-primary:active,.btn-primary.active{background-color:#003399 \9;} +.btn-warning{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#faa732;background-image:-moz-linear-gradient(top, #fbb450, #f89406);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#fbb450), to(#f89406));background-image:-webkit-linear-gradient(top, #fbb450, #f89406);background-image:-o-linear-gradient(top, #fbb450, #f89406);background-image:linear-gradient(to bottom, #fbb450, #f89406);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fffbb450', endColorstr='#fff89406', GradientType=0);border-color:#f89406 #f89406 #ad6704;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#f89406;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-warning:hover,.btn-warning:focus,.btn-warning:active,.btn-warning.active,.btn-warning.disabled,.btn-warning[disabled]{color:#ffffff;background-color:#f89406;*background-color:#df8505;} +.btn-warning:active,.btn-warning.active{background-color:#c67605 \9;} +.btn-danger{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#da4f49;background-image:-moz-linear-gradient(top, #ee5f5b, #bd362f);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ee5f5b), to(#bd362f));background-image:-webkit-linear-gradient(top, #ee5f5b, #bd362f);background-image:-o-linear-gradient(top, #ee5f5b, #bd362f);background-image:linear-gradient(to bottom, #ee5f5b, #bd362f);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffee5f5b', endColorstr='#ffbd362f', GradientType=0);border-color:#bd362f #bd362f #802420;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#bd362f;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-danger:hover,.btn-danger:focus,.btn-danger:active,.btn-danger.active,.btn-danger.disabled,.btn-danger[disabled]{color:#ffffff;background-color:#bd362f;*background-color:#a9302a;} +.btn-danger:active,.btn-danger.active{background-color:#942a25 \9;} +.btn-success{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#5bb75b;background-image:-moz-linear-gradient(top, #62c462, #51a351);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#62c462), to(#51a351));background-image:-webkit-linear-gradient(top, #62c462, #51a351);background-image:-o-linear-gradient(top, #62c462, #51a351);background-image:linear-gradient(to bottom, #62c462, #51a351);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff62c462', endColorstr='#ff51a351', GradientType=0);border-color:#51a351 #51a351 #387038;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#51a351;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-success:hover,.btn-success:focus,.btn-success:active,.btn-success.active,.btn-success.disabled,.btn-success[disabled]{color:#ffffff;background-color:#51a351;*background-color:#499249;} +.btn-success:active,.btn-success.active{background-color:#408140 \9;} +.btn-info{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#49afcd;background-image:-moz-linear-gradient(top, #5bc0de, #2f96b4);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#5bc0de), to(#2f96b4));background-image:-webkit-linear-gradient(top, #5bc0de, #2f96b4);background-image:-o-linear-gradient(top, #5bc0de, #2f96b4);background-image:linear-gradient(to bottom, #5bc0de, #2f96b4);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff5bc0de', endColorstr='#ff2f96b4', GradientType=0);border-color:#2f96b4 #2f96b4 #1f6377;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#2f96b4;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-info:hover,.btn-info:focus,.btn-info:active,.btn-info.active,.btn-info.disabled,.btn-info[disabled]{color:#ffffff;background-color:#2f96b4;*background-color:#2a85a0;} +.btn-info:active,.btn-info.active{background-color:#24748c \9;} +.btn-inverse{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#363636;background-image:-moz-linear-gradient(top, #444444, #222222);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#444444), to(#222222));background-image:-webkit-linear-gradient(top, #444444, #222222);background-image:-o-linear-gradient(top, #444444, #222222);background-image:linear-gradient(to bottom, #444444, #222222);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff444444', endColorstr='#ff222222', GradientType=0);border-color:#222222 #222222 #000000;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#222222;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-inverse:hover,.btn-inverse:focus,.btn-inverse:active,.btn-inverse.active,.btn-inverse.disabled,.btn-inverse[disabled]{color:#ffffff;background-color:#222222;*background-color:#151515;} +.btn-inverse:active,.btn-inverse.active{background-color:#080808 \9;} +button.btn,input[type="submit"].btn{*padding-top:3px;*padding-bottom:3px;}button.btn::-moz-focus-inner,input[type="submit"].btn::-moz-focus-inner{padding:0;border:0;} +button.btn.btn-large,input[type="submit"].btn.btn-large{*padding-top:7px;*padding-bottom:7px;} +button.btn.btn-small,input[type="submit"].btn.btn-small{*padding-top:3px;*padding-bottom:3px;} +button.btn.btn-mini,input[type="submit"].btn.btn-mini{*padding-top:1px;*padding-bottom:1px;} +.btn-link,.btn-link:active,.btn-link[disabled]{background-color:transparent;background-image:none;-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} +.btn-link{border-color:transparent;cursor:pointer;color:#0088cc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-link:hover,.btn-link:focus{color:#005580;text-decoration:underline;background-color:transparent;} +.btn-link[disabled]:hover,.btn-link[disabled]:focus{color:#333333;text-decoration:none;} +[class^="icon-"],[class*=" icon-"]{display:inline-block;width:14px;height:14px;*margin-right:.3em;line-height:14px;vertical-align:text-top;background-image:url("../img/glyphicons-halflings.png");background-position:14px 14px;background-repeat:no-repeat;margin-top:1px;} +.icon-white,.nav-pills>.active>a>[class^="icon-"],.nav-pills>.active>a>[class*=" icon-"],.nav-list>.active>a>[class^="icon-"],.nav-list>.active>a>[class*=" icon-"],.navbar-inverse .nav>.active>a>[class^="icon-"],.navbar-inverse .nav>.active>a>[class*=" icon-"],.dropdown-menu>li>a:hover>[class^="icon-"],.dropdown-menu>li>a:focus>[class^="icon-"],.dropdown-menu>li>a:hover>[class*=" icon-"],.dropdown-menu>li>a:focus>[class*=" icon-"],.dropdown-menu>.active>a>[class^="icon-"],.dropdown-menu>.active>a>[class*=" icon-"],.dropdown-submenu:hover>a>[class^="icon-"],.dropdown-submenu:focus>a>[class^="icon-"],.dropdown-submenu:hover>a>[class*=" icon-"],.dropdown-submenu:focus>a>[class*=" icon-"]{background-image:url("../img/glyphicons-halflings-white.png");} +.icon-glass{background-position:0 0;} +.icon-music{background-position:-24px 0;} +.icon-search{background-position:-48px 0;} +.icon-envelope{background-position:-72px 0;} +.icon-heart{background-position:-96px 0;} +.icon-star{background-position:-120px 0;} +.icon-star-empty{background-position:-144px 0;} +.icon-user{background-position:-168px 0;} +.icon-film{background-position:-192px 0;} +.icon-th-large{background-position:-216px 0;} +.icon-th{background-position:-240px 0;} +.icon-th-list{background-position:-264px 0;} +.icon-ok{background-position:-288px 0;} +.icon-remove{background-position:-312px 0;} +.icon-zoom-in{background-position:-336px 0;} +.icon-zoom-out{background-position:-360px 0;} +.icon-off{background-position:-384px 0;} +.icon-signal{background-position:-408px 0;} +.icon-cog{background-position:-432px 0;} +.icon-trash{background-position:-456px 0;} +.icon-home{background-position:0 -24px;} +.icon-file{background-position:-24px -24px;} +.icon-time{background-position:-48px -24px;} +.icon-road{background-position:-72px -24px;} +.icon-download-alt{background-position:-96px -24px;} +.icon-download{background-position:-120px -24px;} +.icon-upload{background-position:-144px -24px;} +.icon-inbox{background-position:-168px -24px;} +.icon-play-circle{background-position:-192px -24px;} +.icon-repeat{background-position:-216px -24px;} +.icon-refresh{background-position:-240px -24px;} +.icon-list-alt{background-position:-264px -24px;} +.icon-lock{background-position:-287px -24px;} +.icon-flag{background-position:-312px -24px;} +.icon-headphones{background-position:-336px -24px;} +.icon-volume-off{background-position:-360px -24px;} +.icon-volume-down{background-position:-384px -24px;} +.icon-volume-up{background-position:-408px -24px;} +.icon-qrcode{background-position:-432px -24px;} +.icon-barcode{background-position:-456px -24px;} +.icon-tag{background-position:0 -48px;} +.icon-tags{background-position:-25px -48px;} +.icon-book{background-position:-48px -48px;} +.icon-bookmark{background-position:-72px -48px;} +.icon-print{background-position:-96px -48px;} +.icon-camera{background-position:-120px -48px;} +.icon-font{background-position:-144px -48px;} +.icon-bold{background-position:-167px -48px;} +.icon-italic{background-position:-192px -48px;} +.icon-text-height{background-position:-216px -48px;} +.icon-text-width{background-position:-240px -48px;} +.icon-align-left{background-position:-264px -48px;} +.icon-align-center{background-position:-288px -48px;} +.icon-align-right{background-position:-312px -48px;} +.icon-align-justify{background-position:-336px -48px;} +.icon-list{background-position:-360px -48px;} +.icon-indent-left{background-position:-384px -48px;} +.icon-indent-right{background-position:-408px -48px;} +.icon-facetime-video{background-position:-432px -48px;} +.icon-picture{background-position:-456px -48px;} +.icon-pencil{background-position:0 -72px;} +.icon-map-marker{background-position:-24px -72px;} +.icon-adjust{background-position:-48px -72px;} +.icon-tint{background-position:-72px -72px;} +.icon-edit{background-position:-96px -72px;} +.icon-share{background-position:-120px -72px;} +.icon-check{background-position:-144px -72px;} +.icon-move{background-position:-168px -72px;} +.icon-step-backward{background-position:-192px -72px;} +.icon-fast-backward{background-position:-216px -72px;} +.icon-backward{background-position:-240px -72px;} +.icon-play{background-position:-264px -72px;} +.icon-pause{background-position:-288px -72px;} +.icon-stop{background-position:-312px -72px;} +.icon-forward{background-position:-336px -72px;} +.icon-fast-forward{background-position:-360px -72px;} +.icon-step-forward{background-position:-384px -72px;} +.icon-eject{background-position:-408px -72px;} +.icon-chevron-left{background-position:-432px -72px;} +.icon-chevron-right{background-position:-456px -72px;} +.icon-plus-sign{background-position:0 -96px;} +.icon-minus-sign{background-position:-24px -96px;} +.icon-remove-sign{background-position:-48px -96px;} +.icon-ok-sign{background-position:-72px -96px;} +.icon-question-sign{background-position:-96px -96px;} +.icon-info-sign{background-position:-120px -96px;} +.icon-screenshot{background-position:-144px -96px;} +.icon-remove-circle{background-position:-168px -96px;} +.icon-ok-circle{background-position:-192px -96px;} +.icon-ban-circle{background-position:-216px -96px;} +.icon-arrow-left{background-position:-240px -96px;} +.icon-arrow-right{background-position:-264px -96px;} +.icon-arrow-up{background-position:-289px -96px;} +.icon-arrow-down{background-position:-312px -96px;} +.icon-share-alt{background-position:-336px -96px;} +.icon-resize-full{background-position:-360px -96px;} +.icon-resize-small{background-position:-384px -96px;} +.icon-plus{background-position:-408px -96px;} +.icon-minus{background-position:-433px -96px;} +.icon-asterisk{background-position:-456px -96px;} +.icon-exclamation-sign{background-position:0 -120px;} +.icon-gift{background-position:-24px -120px;} +.icon-leaf{background-position:-48px -120px;} +.icon-fire{background-position:-72px -120px;} +.icon-eye-open{background-position:-96px -120px;} +.icon-eye-close{background-position:-120px -120px;} +.icon-warning-sign{background-position:-144px -120px;} +.icon-plane{background-position:-168px -120px;} +.icon-calendar{background-position:-192px -120px;} +.icon-random{background-position:-216px -120px;width:16px;} +.icon-comment{background-position:-240px -120px;} +.icon-magnet{background-position:-264px -120px;} +.icon-chevron-up{background-position:-288px -120px;} +.icon-chevron-down{background-position:-313px -119px;} +.icon-retweet{background-position:-336px -120px;} +.icon-shopping-cart{background-position:-360px -120px;} +.icon-folder-close{background-position:-384px -120px;width:16px;} +.icon-folder-open{background-position:-408px -120px;width:16px;} +.icon-resize-vertical{background-position:-432px -119px;} +.icon-resize-horizontal{background-position:-456px -118px;} +.icon-hdd{background-position:0 -144px;} +.icon-bullhorn{background-position:-24px -144px;} +.icon-bell{background-position:-48px -144px;} +.icon-certificate{background-position:-72px -144px;} +.icon-thumbs-up{background-position:-96px -144px;} +.icon-thumbs-down{background-position:-120px -144px;} +.icon-hand-right{background-position:-144px -144px;} +.icon-hand-left{background-position:-168px -144px;} +.icon-hand-up{background-position:-192px -144px;} +.icon-hand-down{background-position:-216px -144px;} +.icon-circle-arrow-right{background-position:-240px -144px;} +.icon-circle-arrow-left{background-position:-264px -144px;} +.icon-circle-arrow-up{background-position:-288px -144px;} +.icon-circle-arrow-down{background-position:-312px -144px;} +.icon-globe{background-position:-336px -144px;} +.icon-wrench{background-position:-360px -144px;} +.icon-tasks{background-position:-384px -144px;} +.icon-filter{background-position:-408px -144px;} +.icon-briefcase{background-position:-432px -144px;} +.icon-fullscreen{background-position:-456px -144px;} +.btn-group{position:relative;display:inline-block;*display:inline;*zoom:1;font-size:0;vertical-align:middle;white-space:nowrap;*margin-left:.3em;}.btn-group:first-child{*margin-left:0;} +.btn-group+.btn-group{margin-left:5px;} +.btn-toolbar{font-size:0;margin-top:10px;margin-bottom:10px;}.btn-toolbar>.btn+.btn,.btn-toolbar>.btn-group+.btn,.btn-toolbar>.btn+.btn-group{margin-left:5px;} +.btn-group>.btn{position:relative;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-group>.btn+.btn{margin-left:-1px;} +.btn-group>.btn,.btn-group>.dropdown-menu,.btn-group>.popover{font-size:14px;} +.btn-group>.btn-mini{font-size:10.5px;} +.btn-group>.btn-small{font-size:11.9px;} +.btn-group>.btn-large{font-size:17.5px;} +.btn-group>.btn:first-child{margin-left:0;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.btn-group>.btn:last-child,.btn-group>.dropdown-toggle{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.btn-group>.btn.large:first-child{margin-left:0;-webkit-border-top-left-radius:6px;-moz-border-radius-topleft:6px;border-top-left-radius:6px;-webkit-border-bottom-left-radius:6px;-moz-border-radius-bottomleft:6px;border-bottom-left-radius:6px;} +.btn-group>.btn.large:last-child,.btn-group>.large.dropdown-toggle{-webkit-border-top-right-radius:6px;-moz-border-radius-topright:6px;border-top-right-radius:6px;-webkit-border-bottom-right-radius:6px;-moz-border-radius-bottomright:6px;border-bottom-right-radius:6px;} +.btn-group>.btn:hover,.btn-group>.btn:focus,.btn-group>.btn:active,.btn-group>.btn.active{z-index:2;} +.btn-group .dropdown-toggle:active,.btn-group.open .dropdown-toggle{outline:0;} +.btn-group>.btn+.dropdown-toggle{padding-left:8px;padding-right:8px;-webkit-box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);*padding-top:5px;*padding-bottom:5px;} +.btn-group>.btn-mini+.dropdown-toggle{padding-left:5px;padding-right:5px;*padding-top:2px;*padding-bottom:2px;} +.btn-group>.btn-small+.dropdown-toggle{*padding-top:5px;*padding-bottom:4px;} +.btn-group>.btn-large+.dropdown-toggle{padding-left:12px;padding-right:12px;*padding-top:7px;*padding-bottom:7px;} +.btn-group.open .dropdown-toggle{background-image:none;-webkit-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);} +.btn-group.open .btn.dropdown-toggle{background-color:#e6e6e6;} +.btn-group.open .btn-primary.dropdown-toggle{background-color:#0044cc;} +.btn-group.open .btn-warning.dropdown-toggle{background-color:#f89406;} +.btn-group.open .btn-danger.dropdown-toggle{background-color:#bd362f;} +.btn-group.open .btn-success.dropdown-toggle{background-color:#51a351;} +.btn-group.open .btn-info.dropdown-toggle{background-color:#2f96b4;} +.btn-group.open .btn-inverse.dropdown-toggle{background-color:#222222;} +.btn .caret{margin-top:8px;margin-left:0;} +.btn-large .caret{margin-top:6px;} +.btn-large .caret{border-left-width:5px;border-right-width:5px;border-top-width:5px;} +.btn-mini .caret,.btn-small .caret{margin-top:8px;} +.dropup .btn-large .caret{border-bottom-width:5px;} +.btn-primary .caret,.btn-warning .caret,.btn-danger .caret,.btn-info .caret,.btn-success .caret,.btn-inverse .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.btn-group-vertical{display:inline-block;*display:inline;*zoom:1;} +.btn-group-vertical>.btn{display:block;float:none;max-width:100%;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-group-vertical>.btn+.btn{margin-left:0;margin-top:-1px;} +.btn-group-vertical>.btn:first-child{-webkit-border-radius:4px 4px 0 0;-moz-border-radius:4px 4px 0 0;border-radius:4px 4px 0 0;} +.btn-group-vertical>.btn:last-child{-webkit-border-radius:0 0 4px 4px;-moz-border-radius:0 0 4px 4px;border-radius:0 0 4px 4px;} +.btn-group-vertical>.btn-large:first-child{-webkit-border-radius:6px 6px 0 0;-moz-border-radius:6px 6px 0 0;border-radius:6px 6px 0 0;} +.btn-group-vertical>.btn-large:last-child{-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;} +.nav{margin-left:0;margin-bottom:20px;list-style:none;} +.nav>li>a{display:block;} +.nav>li>a:hover,.nav>li>a:focus{text-decoration:none;background-color:#eeeeee;} +.nav>li>a>img{max-width:none;} +.nav>.pull-right{float:right;} +.nav-header{display:block;padding:3px 15px;font-size:11px;font-weight:bold;line-height:20px;color:#999999;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);text-transform:uppercase;} +.nav li+.nav-header{margin-top:9px;} +.nav-list{padding-left:15px;padding-right:15px;margin-bottom:0;} +.nav-list>li>a,.nav-list .nav-header{margin-left:-15px;margin-right:-15px;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);} +.nav-list>li>a{padding:3px 15px;} +.nav-list>.active>a,.nav-list>.active>a:hover,.nav-list>.active>a:focus{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.2);background-color:#0088cc;} +.nav-list [class^="icon-"],.nav-list [class*=" icon-"]{margin-right:2px;} +.nav-list .divider{*width:100%;height:1px;margin:9px 1px;*margin:-5px 0 5px;overflow:hidden;background-color:#e5e5e5;border-bottom:1px solid #ffffff;} +.nav-tabs,.nav-pills{*zoom:1;}.nav-tabs:before,.nav-pills:before,.nav-tabs:after,.nav-pills:after{display:table;content:"";line-height:0;} +.nav-tabs:after,.nav-pills:after{clear:both;} +.nav-tabs>li,.nav-pills>li{float:left;} +.nav-tabs>li>a,.nav-pills>li>a{padding-right:12px;padding-left:12px;margin-right:2px;line-height:14px;} +.nav-tabs{border-bottom:1px solid #ddd;} +.nav-tabs>li{margin-bottom:-1px;} +.nav-tabs>li>a{padding-top:8px;padding-bottom:8px;line-height:20px;border:1px solid transparent;-webkit-border-radius:4px 4px 0 0;-moz-border-radius:4px 4px 0 0;border-radius:4px 4px 0 0;}.nav-tabs>li>a:hover,.nav-tabs>li>a:focus{border-color:#eeeeee #eeeeee #dddddd;} +.nav-tabs>.active>a,.nav-tabs>.active>a:hover,.nav-tabs>.active>a:focus{color:#555555;background-color:#ffffff;border:1px solid #ddd;border-bottom-color:transparent;cursor:default;} +.nav-pills>li>a{padding-top:8px;padding-bottom:8px;margin-top:2px;margin-bottom:2px;-webkit-border-radius:5px;-moz-border-radius:5px;border-radius:5px;} +.nav-pills>.active>a,.nav-pills>.active>a:hover,.nav-pills>.active>a:focus{color:#ffffff;background-color:#0088cc;} +.nav-stacked>li{float:none;} +.nav-stacked>li>a{margin-right:0;} +.nav-tabs.nav-stacked{border-bottom:0;} +.nav-tabs.nav-stacked>li>a{border:1px solid #ddd;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.nav-tabs.nav-stacked>li:first-child>a{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.nav-tabs.nav-stacked>li:last-child>a{-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.nav-tabs.nav-stacked>li>a:hover,.nav-tabs.nav-stacked>li>a:focus{border-color:#ddd;z-index:2;} +.nav-pills.nav-stacked>li>a{margin-bottom:3px;} +.nav-pills.nav-stacked>li:last-child>a{margin-bottom:1px;} +.nav-tabs .dropdown-menu{-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;} +.nav-pills .dropdown-menu{-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.nav .dropdown-toggle .caret{border-top-color:#0088cc;border-bottom-color:#0088cc;margin-top:6px;} +.nav .dropdown-toggle:hover .caret,.nav .dropdown-toggle:focus .caret{border-top-color:#005580;border-bottom-color:#005580;} +.nav-tabs .dropdown-toggle .caret{margin-top:8px;} +.nav .active .dropdown-toggle .caret{border-top-color:#fff;border-bottom-color:#fff;} +.nav-tabs .active .dropdown-toggle .caret{border-top-color:#555555;border-bottom-color:#555555;} +.nav>.dropdown.active>a:hover,.nav>.dropdown.active>a:focus{cursor:pointer;} +.nav-tabs .open .dropdown-toggle,.nav-pills .open .dropdown-toggle,.nav>li.dropdown.open.active>a:hover,.nav>li.dropdown.open.active>a:focus{color:#ffffff;background-color:#999999;border-color:#999999;} +.nav li.dropdown.open .caret,.nav li.dropdown.open.active .caret,.nav li.dropdown.open a:hover .caret,.nav li.dropdown.open a:focus .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;opacity:1;filter:alpha(opacity=100);} +.tabs-stacked .open>a:hover,.tabs-stacked .open>a:focus{border-color:#999999;} +.tabbable{*zoom:1;}.tabbable:before,.tabbable:after{display:table;content:"";line-height:0;} +.tabbable:after{clear:both;} +.tab-content{overflow:auto;} +.tabs-below>.nav-tabs,.tabs-right>.nav-tabs,.tabs-left>.nav-tabs{border-bottom:0;} +.tab-content>.tab-pane,.pill-content>.pill-pane{display:none;} +.tab-content>.active,.pill-content>.active{display:block;} +.tabs-below>.nav-tabs{border-top:1px solid #ddd;} +.tabs-below>.nav-tabs>li{margin-top:-1px;margin-bottom:0;} +.tabs-below>.nav-tabs>li>a{-webkit-border-radius:0 0 4px 4px;-moz-border-radius:0 0 4px 4px;border-radius:0 0 4px 4px;}.tabs-below>.nav-tabs>li>a:hover,.tabs-below>.nav-tabs>li>a:focus{border-bottom-color:transparent;border-top-color:#ddd;} +.tabs-below>.nav-tabs>.active>a,.tabs-below>.nav-tabs>.active>a:hover,.tabs-below>.nav-tabs>.active>a:focus{border-color:transparent #ddd #ddd #ddd;} +.tabs-left>.nav-tabs>li,.tabs-right>.nav-tabs>li{float:none;} +.tabs-left>.nav-tabs>li>a,.tabs-right>.nav-tabs>li>a{min-width:74px;margin-right:0;margin-bottom:3px;} +.tabs-left>.nav-tabs{float:left;margin-right:19px;border-right:1px solid #ddd;} +.tabs-left>.nav-tabs>li>a{margin-right:-1px;-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.tabs-left>.nav-tabs>li>a:hover,.tabs-left>.nav-tabs>li>a:focus{border-color:#eeeeee #dddddd #eeeeee #eeeeee;} +.tabs-left>.nav-tabs .active>a,.tabs-left>.nav-tabs .active>a:hover,.tabs-left>.nav-tabs .active>a:focus{border-color:#ddd transparent #ddd #ddd;*border-right-color:#ffffff;} +.tabs-right>.nav-tabs{float:right;margin-left:19px;border-left:1px solid #ddd;} +.tabs-right>.nav-tabs>li>a{margin-left:-1px;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.tabs-right>.nav-tabs>li>a:hover,.tabs-right>.nav-tabs>li>a:focus{border-color:#eeeeee #eeeeee #eeeeee #dddddd;} +.tabs-right>.nav-tabs .active>a,.tabs-right>.nav-tabs .active>a:hover,.tabs-right>.nav-tabs .active>a:focus{border-color:#ddd #ddd #ddd transparent;*border-left-color:#ffffff;} +.nav>.disabled>a{color:#999999;} +.nav>.disabled>a:hover,.nav>.disabled>a:focus{text-decoration:none;background-color:transparent;cursor:default;} +.navbar{overflow:visible;margin-bottom:20px;*position:relative;*z-index:2;} +.navbar-inner{min-height:40px;padding-left:20px;padding-right:20px;background-color:#fafafa;background-image:-moz-linear-gradient(top, #ffffff, #f2f2f2);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ffffff), to(#f2f2f2));background-image:-webkit-linear-gradient(top, #ffffff, #f2f2f2);background-image:-o-linear-gradient(top, #ffffff, #f2f2f2);background-image:linear-gradient(to bottom, #ffffff, #f2f2f2);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffffff', endColorstr='#fff2f2f2', GradientType=0);border:1px solid #d4d4d4;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);-moz-box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);*zoom:1;}.navbar-inner:before,.navbar-inner:after{display:table;content:"";line-height:0;} +.navbar-inner:after{clear:both;} +.navbar .container{width:auto;} +.nav-collapse.collapse{height:auto;overflow:visible;} +.navbar .brand{float:left;display:block;padding:10px 20px 10px;margin-left:-20px;font-size:20px;font-weight:200;color:#777777;text-shadow:0 1px 0 #ffffff;}.navbar .brand:hover,.navbar .brand:focus{text-decoration:none;} +.navbar-text{margin-bottom:0;line-height:40px;color:#777777;} +.navbar-link{color:#777777;}.navbar-link:hover,.navbar-link:focus{color:#333333;} +.navbar .divider-vertical{height:40px;margin:0 9px;border-left:1px solid #f2f2f2;border-right:1px solid #ffffff;} +.navbar .btn,.navbar .btn-group{margin-top:5px;} +.navbar .btn-group .btn,.navbar .input-prepend .btn,.navbar .input-append .btn,.navbar .input-prepend .btn-group,.navbar .input-append .btn-group{margin-top:0;} +.navbar-form{margin-bottom:0;*zoom:1;}.navbar-form:before,.navbar-form:after{display:table;content:"";line-height:0;} +.navbar-form:after{clear:both;} +.navbar-form input,.navbar-form select,.navbar-form .radio,.navbar-form .checkbox{margin-top:5px;} +.navbar-form input,.navbar-form select,.navbar-form .btn{display:inline-block;margin-bottom:0;} +.navbar-form input[type="image"],.navbar-form input[type="checkbox"],.navbar-form input[type="radio"]{margin-top:3px;} +.navbar-form .input-append,.navbar-form .input-prepend{margin-top:5px;white-space:nowrap;}.navbar-form .input-append input,.navbar-form .input-prepend input{margin-top:0;} +.navbar-search{position:relative;float:left;margin-top:5px;margin-bottom:0;}.navbar-search .search-query{margin-bottom:0;padding:4px 14px;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:13px;font-weight:normal;line-height:1;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.navbar-static-top{position:static;margin-bottom:0;}.navbar-static-top .navbar-inner{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.navbar-fixed-top,.navbar-fixed-bottom{position:fixed;right:0;left:0;z-index:1030;margin-bottom:0;} +.navbar-fixed-top .navbar-inner,.navbar-static-top .navbar-inner{border-width:0 0 1px;} +.navbar-fixed-bottom .navbar-inner{border-width:1px 0 0;} +.navbar-fixed-top .navbar-inner,.navbar-fixed-bottom .navbar-inner{padding-left:0;padding-right:0;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:940px;} +.navbar-fixed-top{top:0;} +.navbar-fixed-top .navbar-inner,.navbar-static-top .navbar-inner{-webkit-box-shadow:0 1px 10px rgba(0,0,0,.1);-moz-box-shadow:0 1px 10px rgba(0,0,0,.1);box-shadow:0 1px 10px rgba(0,0,0,.1);} +.navbar-fixed-bottom{bottom:0;}.navbar-fixed-bottom .navbar-inner{-webkit-box-shadow:0 -1px 10px rgba(0,0,0,.1);-moz-box-shadow:0 -1px 10px rgba(0,0,0,.1);box-shadow:0 -1px 10px rgba(0,0,0,.1);} +.navbar .nav{position:relative;left:0;display:block;float:left;margin:0 10px 0 0;} +.navbar .nav.pull-right{float:right;margin-right:0;} +.navbar .nav>li{float:left;} +.navbar .nav>li>a{float:none;padding:10px 15px 10px;color:#777777;text-decoration:none;text-shadow:0 1px 0 #ffffff;} +.navbar .nav .dropdown-toggle .caret{margin-top:8px;} +.navbar .nav>li>a:focus,.navbar .nav>li>a:hover{background-color:transparent;color:#333333;text-decoration:none;} +.navbar .nav>.active>a,.navbar .nav>.active>a:hover,.navbar .nav>.active>a:focus{color:#555555;text-decoration:none;background-color:#e5e5e5;-webkit-box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);-moz-box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);} +.navbar .btn-navbar{display:none;float:right;padding:7px 10px;margin-left:5px;margin-right:5px;color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#ededed;background-image:-moz-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#f2f2f2), to(#e5e5e5));background-image:-webkit-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:-o-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:linear-gradient(to bottom, #f2f2f2, #e5e5e5);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fff2f2f2', endColorstr='#ffe5e5e5', GradientType=0);border-color:#e5e5e5 #e5e5e5 #bfbfbf;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#e5e5e5;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);}.navbar .btn-navbar:hover,.navbar .btn-navbar:focus,.navbar .btn-navbar:active,.navbar .btn-navbar.active,.navbar .btn-navbar.disabled,.navbar .btn-navbar[disabled]{color:#ffffff;background-color:#e5e5e5;*background-color:#d9d9d9;} +.navbar .btn-navbar:active,.navbar .btn-navbar.active{background-color:#cccccc \9;} +.navbar .btn-navbar .icon-bar{display:block;width:18px;height:2px;background-color:#f5f5f5;-webkit-border-radius:1px;-moz-border-radius:1px;border-radius:1px;-webkit-box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);-moz-box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);} +.btn-navbar .icon-bar+.icon-bar{margin-top:3px;} +.navbar .nav>li>.dropdown-menu:before{content:'';display:inline-block;border-left:7px solid transparent;border-right:7px solid transparent;border-bottom:7px solid #ccc;border-bottom-color:rgba(0, 0, 0, 0.2);position:absolute;top:-7px;left:9px;} +.navbar .nav>li>.dropdown-menu:after{content:'';display:inline-block;border-left:6px solid transparent;border-right:6px solid transparent;border-bottom:6px solid #ffffff;position:absolute;top:-6px;left:10px;} +.navbar-fixed-bottom .nav>li>.dropdown-menu:before{border-top:7px solid #ccc;border-top-color:rgba(0, 0, 0, 0.2);border-bottom:0;bottom:-7px;top:auto;} +.navbar-fixed-bottom .nav>li>.dropdown-menu:after{border-top:6px solid #ffffff;border-bottom:0;bottom:-6px;top:auto;} +.navbar .nav li.dropdown>a:hover .caret,.navbar .nav li.dropdown>a:focus .caret{border-top-color:#333333;border-bottom-color:#333333;} +.navbar .nav li.dropdown.open>.dropdown-toggle,.navbar .nav li.dropdown.active>.dropdown-toggle,.navbar .nav li.dropdown.open.active>.dropdown-toggle{background-color:#e5e5e5;color:#555555;} +.navbar .nav li.dropdown>.dropdown-toggle .caret{border-top-color:#777777;border-bottom-color:#777777;} +.navbar .nav li.dropdown.open>.dropdown-toggle .caret,.navbar .nav li.dropdown.active>.dropdown-toggle .caret,.navbar .nav li.dropdown.open.active>.dropdown-toggle .caret{border-top-color:#555555;border-bottom-color:#555555;} +.navbar .pull-right>li>.dropdown-menu,.navbar .nav>li>.dropdown-menu.pull-right{left:auto;right:0;}.navbar .pull-right>li>.dropdown-menu:before,.navbar .nav>li>.dropdown-menu.pull-right:before{left:auto;right:12px;} +.navbar .pull-right>li>.dropdown-menu:after,.navbar .nav>li>.dropdown-menu.pull-right:after{left:auto;right:13px;} +.navbar .pull-right>li>.dropdown-menu .dropdown-menu,.navbar .nav>li>.dropdown-menu.pull-right .dropdown-menu{left:auto;right:100%;margin-left:0;margin-right:-1px;-webkit-border-radius:6px 0 6px 6px;-moz-border-radius:6px 0 6px 6px;border-radius:6px 0 6px 6px;} +.navbar-inverse .navbar-inner{background-color:#1b1b1b;background-image:-moz-linear-gradient(top, #222222, #111111);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#222222), to(#111111));background-image:-webkit-linear-gradient(top, #222222, #111111);background-image:-o-linear-gradient(top, #222222, #111111);background-image:linear-gradient(to bottom, #222222, #111111);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff222222', endColorstr='#ff111111', GradientType=0);border-color:#252525;} +.navbar-inverse .brand,.navbar-inverse .nav>li>a{color:#999999;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);}.navbar-inverse .brand:hover,.navbar-inverse .nav>li>a:hover,.navbar-inverse .brand:focus,.navbar-inverse .nav>li>a:focus{color:#ffffff;} +.navbar-inverse .brand{color:#999999;} +.navbar-inverse .navbar-text{color:#999999;} +.navbar-inverse .nav>li>a:focus,.navbar-inverse .nav>li>a:hover{background-color:transparent;color:#ffffff;} +.navbar-inverse .nav .active>a,.navbar-inverse .nav .active>a:hover,.navbar-inverse .nav .active>a:focus{color:#ffffff;background-color:#111111;} +.navbar-inverse .navbar-link{color:#999999;}.navbar-inverse .navbar-link:hover,.navbar-inverse .navbar-link:focus{color:#ffffff;} +.navbar-inverse .divider-vertical{border-left-color:#111111;border-right-color:#222222;} +.navbar-inverse .nav li.dropdown.open>.dropdown-toggle,.navbar-inverse .nav li.dropdown.active>.dropdown-toggle,.navbar-inverse .nav li.dropdown.open.active>.dropdown-toggle{background-color:#111111;color:#ffffff;} +.navbar-inverse .nav li.dropdown>a:hover .caret,.navbar-inverse .nav li.dropdown>a:focus .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.navbar-inverse .nav li.dropdown>.dropdown-toggle .caret{border-top-color:#999999;border-bottom-color:#999999;} +.navbar-inverse .nav li.dropdown.open>.dropdown-toggle .caret,.navbar-inverse .nav li.dropdown.active>.dropdown-toggle .caret,.navbar-inverse .nav li.dropdown.open.active>.dropdown-toggle .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.navbar-inverse .navbar-search .search-query{color:#ffffff;background-color:#515151;border-color:#111111;-webkit-box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);-moz-box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);-webkit-transition:none;-moz-transition:none;-o-transition:none;transition:none;}.navbar-inverse .navbar-search .search-query:-moz-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query:-ms-input-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query::-webkit-input-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query:focus,.navbar-inverse .navbar-search .search-query.focused{padding:5px 15px;color:#333333;text-shadow:0 1px 0 #ffffff;background-color:#ffffff;border:0;-webkit-box-shadow:0 0 3px rgba(0, 0, 0, 0.15);-moz-box-shadow:0 0 3px rgba(0, 0, 0, 0.15);box-shadow:0 0 3px rgba(0, 0, 0, 0.15);outline:0;} +.navbar-inverse .btn-navbar{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#0e0e0e;background-image:-moz-linear-gradient(top, #151515, #040404);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#151515), to(#040404));background-image:-webkit-linear-gradient(top, #151515, #040404);background-image:-o-linear-gradient(top, #151515, #040404);background-image:linear-gradient(to bottom, #151515, #040404);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff151515', endColorstr='#ff040404', GradientType=0);border-color:#040404 #040404 #000000;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#040404;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.navbar-inverse .btn-navbar:hover,.navbar-inverse .btn-navbar:focus,.navbar-inverse .btn-navbar:active,.navbar-inverse .btn-navbar.active,.navbar-inverse .btn-navbar.disabled,.navbar-inverse .btn-navbar[disabled]{color:#ffffff;background-color:#040404;*background-color:#000000;} +.navbar-inverse .btn-navbar:active,.navbar-inverse .btn-navbar.active{background-color:#000000 \9;} +.breadcrumb{padding:8px 15px;margin:0 0 20px;list-style:none;background-color:#f5f5f5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}.breadcrumb>li{display:inline-block;*display:inline;*zoom:1;text-shadow:0 1px 0 #ffffff;}.breadcrumb>li>.divider{padding:0 5px;color:#ccc;} +.breadcrumb>.active{color:#999999;} +.pagination{margin:20px 0;} +.pagination ul{display:inline-block;*display:inline;*zoom:1;margin-left:0;margin-bottom:0;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);-moz-box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);} +.pagination ul>li{display:inline;} +.pagination ul>li>a,.pagination ul>li>span{float:left;padding:4px 12px;line-height:20px;text-decoration:none;background-color:#ffffff;border:1px solid #dddddd;border-left-width:0;} +.pagination ul>li>a:hover,.pagination ul>li>a:focus,.pagination ul>.active>a,.pagination ul>.active>span{background-color:#f5f5f5;} +.pagination ul>.active>a,.pagination ul>.active>span{color:#999999;cursor:default;} +.pagination ul>.disabled>span,.pagination ul>.disabled>a,.pagination ul>.disabled>a:hover,.pagination ul>.disabled>a:focus{color:#999999;background-color:transparent;cursor:default;} +.pagination ul>li:first-child>a,.pagination ul>li:first-child>span{border-left-width:1px;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.pagination ul>li:last-child>a,.pagination ul>li:last-child>span{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.pagination-centered{text-align:center;} +.pagination-right{text-align:right;} +.pagination-large ul>li>a,.pagination-large ul>li>span{padding:11px 19px;font-size:17.5px;} +.pagination-large ul>li:first-child>a,.pagination-large ul>li:first-child>span{-webkit-border-top-left-radius:6px;-moz-border-radius-topleft:6px;border-top-left-radius:6px;-webkit-border-bottom-left-radius:6px;-moz-border-radius-bottomleft:6px;border-bottom-left-radius:6px;} +.pagination-large ul>li:last-child>a,.pagination-large ul>li:last-child>span{-webkit-border-top-right-radius:6px;-moz-border-radius-topright:6px;border-top-right-radius:6px;-webkit-border-bottom-right-radius:6px;-moz-border-radius-bottomright:6px;border-bottom-right-radius:6px;} +.pagination-mini ul>li:first-child>a,.pagination-small ul>li:first-child>a,.pagination-mini ul>li:first-child>span,.pagination-small ul>li:first-child>span{-webkit-border-top-left-radius:3px;-moz-border-radius-topleft:3px;border-top-left-radius:3px;-webkit-border-bottom-left-radius:3px;-moz-border-radius-bottomleft:3px;border-bottom-left-radius:3px;} +.pagination-mini ul>li:last-child>a,.pagination-small ul>li:last-child>a,.pagination-mini ul>li:last-child>span,.pagination-small ul>li:last-child>span{-webkit-border-top-right-radius:3px;-moz-border-radius-topright:3px;border-top-right-radius:3px;-webkit-border-bottom-right-radius:3px;-moz-border-radius-bottomright:3px;border-bottom-right-radius:3px;} +.pagination-small ul>li>a,.pagination-small ul>li>span{padding:2px 10px;font-size:11.9px;} +.pagination-mini ul>li>a,.pagination-mini ul>li>span{padding:0 6px;font-size:10.5px;} +.pager{margin:20px 0;list-style:none;text-align:center;*zoom:1;}.pager:before,.pager:after{display:table;content:"";line-height:0;} +.pager:after{clear:both;} +.pager li{display:inline;} +.pager li>a,.pager li>span{display:inline-block;padding:5px 14px;background-color:#fff;border:1px solid #ddd;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.pager li>a:hover,.pager li>a:focus{text-decoration:none;background-color:#f5f5f5;} +.pager .next>a,.pager .next>span{float:right;} +.pager .previous>a,.pager .previous>span{float:left;} +.pager .disabled>a,.pager .disabled>a:hover,.pager .disabled>a:focus,.pager .disabled>span{color:#999999;background-color:#fff;cursor:default;} +.thumbnails{margin-left:-20px;list-style:none;*zoom:1;}.thumbnails:before,.thumbnails:after{display:table;content:"";line-height:0;} +.thumbnails:after{clear:both;} +.row-fluid .thumbnails{margin-left:0;} +.thumbnails>li{float:left;margin-bottom:20px;margin-left:20px;} +.thumbnail{display:block;padding:4px;line-height:20px;border:1px solid #ddd;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);-moz-box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);-webkit-transition:all 0.2s ease-in-out;-moz-transition:all 0.2s ease-in-out;-o-transition:all 0.2s ease-in-out;transition:all 0.2s ease-in-out;} +a.thumbnail:hover,a.thumbnail:focus{border-color:#0088cc;-webkit-box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);-moz-box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);} +.thumbnail>img{display:block;max-width:100%;margin-left:auto;margin-right:auto;} +.thumbnail .caption{padding:9px;color:#555555;} +.alert{padding:8px 35px 8px 14px;margin-bottom:20px;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);background-color:#fcf8e3;border:1px solid #fbeed5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.alert,.alert h4{color:#c09853;} +.alert h4{margin:0;} +.alert .close{position:relative;top:-2px;right:-21px;line-height:20px;} +.alert-success{background-color:#dff0d8;border-color:#d6e9c6;color:#468847;} +.alert-success h4{color:#468847;} +.alert-danger,.alert-error{background-color:#f2dede;border-color:#eed3d7;color:#b94a48;} +.alert-danger h4,.alert-error h4{color:#b94a48;} +.alert-info{background-color:#d9edf7;border-color:#bce8f1;color:#3a87ad;} +.alert-info h4{color:#3a87ad;} +.alert-block{padding-top:14px;padding-bottom:14px;} +.alert-block>p,.alert-block>ul{margin-bottom:0;} +.alert-block p+p{margin-top:5px;} +@-webkit-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-moz-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-ms-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-o-keyframes progress-bar-stripes{from{background-position:0 0;} to{background-position:40px 0;}}@keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}.progress{overflow:hidden;height:20px;margin-bottom:20px;background-color:#f7f7f7;background-image:-moz-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#f5f5f5), to(#f9f9f9));background-image:-webkit-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:-o-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:linear-gradient(to bottom, #f5f5f5, #f9f9f9);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fff5f5f5', endColorstr='#fff9f9f9', GradientType=0);-webkit-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);-moz-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.progress .bar{width:0%;height:100%;color:#ffffff;float:left;font-size:12px;text-align:center;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#0e90d2;background-image:-moz-linear-gradient(top, #149bdf, #0480be);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#149bdf), to(#0480be));background-image:-webkit-linear-gradient(top, #149bdf, #0480be);background-image:-o-linear-gradient(top, #149bdf, #0480be);background-image:linear-gradient(to bottom, #149bdf, #0480be);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff149bdf', endColorstr='#ff0480be', GradientType=0);-webkit-box-shadow:inset 0 -1px 0 rgba(0, 0, 0, 0.15);-moz-box-shadow:inset 0 -1px 0 rgba(0, 0, 0, 0.15);box-shadow:inset 0 -1px 0 rgba(0, 0, 0, 0.15);-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;-webkit-transition:width 0.6s ease;-moz-transition:width 0.6s ease;-o-transition:width 0.6s ease;transition:width 0.6s ease;} +.progress .bar+.bar{-webkit-box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);-moz-box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);} +.progress-striped .bar{background-color:#149bdf;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);-webkit-background-size:40px 40px;-moz-background-size:40px 40px;-o-background-size:40px 40px;background-size:40px 40px;} +.progress.active .bar{-webkit-animation:progress-bar-stripes 2s linear infinite;-moz-animation:progress-bar-stripes 2s linear infinite;-ms-animation:progress-bar-stripes 2s linear infinite;-o-animation:progress-bar-stripes 2s linear infinite;animation:progress-bar-stripes 2s linear infinite;} +.progress-danger .bar,.progress .bar-danger{background-color:#dd514c;background-image:-moz-linear-gradient(top, #ee5f5b, #c43c35);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ee5f5b), to(#c43c35));background-image:-webkit-linear-gradient(top, #ee5f5b, #c43c35);background-image:-o-linear-gradient(top, #ee5f5b, #c43c35);background-image:linear-gradient(to bottom, #ee5f5b, #c43c35);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffee5f5b', endColorstr='#ffc43c35', GradientType=0);} +.progress-danger.progress-striped .bar,.progress-striped .bar-danger{background-color:#ee5f5b;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-success .bar,.progress .bar-success{background-color:#5eb95e;background-image:-moz-linear-gradient(top, #62c462, #57a957);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#62c462), to(#57a957));background-image:-webkit-linear-gradient(top, #62c462, #57a957);background-image:-o-linear-gradient(top, #62c462, #57a957);background-image:linear-gradient(to bottom, #62c462, #57a957);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff62c462', endColorstr='#ff57a957', GradientType=0);} +.progress-success.progress-striped .bar,.progress-striped .bar-success{background-color:#62c462;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-info .bar,.progress .bar-info{background-color:#4bb1cf;background-image:-moz-linear-gradient(top, #5bc0de, #339bb9);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#5bc0de), to(#339bb9));background-image:-webkit-linear-gradient(top, #5bc0de, #339bb9);background-image:-o-linear-gradient(top, #5bc0de, #339bb9);background-image:linear-gradient(to bottom, #5bc0de, #339bb9);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff5bc0de', endColorstr='#ff339bb9', GradientType=0);} +.progress-info.progress-striped .bar,.progress-striped .bar-info{background-color:#5bc0de;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-warning .bar,.progress .bar-warning{background-color:#faa732;background-image:-moz-linear-gradient(top, #fbb450, #f89406);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#fbb450), to(#f89406));background-image:-webkit-linear-gradient(top, #fbb450, #f89406);background-image:-o-linear-gradient(top, #fbb450, #f89406);background-image:linear-gradient(to bottom, #fbb450, #f89406);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fffbb450', endColorstr='#fff89406', GradientType=0);} +.progress-warning.progress-striped .bar,.progress-striped .bar-warning{background-color:#fbb450;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.hero-unit{padding:60px;margin-bottom:30px;font-size:18px;font-weight:200;line-height:30px;color:inherit;background-color:#eeeeee;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;}.hero-unit h1{margin-bottom:0;font-size:60px;line-height:1;color:inherit;letter-spacing:-1px;} +.hero-unit li{line-height:30px;} +.media,.media-body{overflow:hidden;*overflow:visible;zoom:1;} +.media,.media .media{margin-top:15px;} +.media:first-child{margin-top:0;} +.media-object{display:block;} +.media-heading{margin:0 0 5px;} +.media>.pull-left{margin-right:10px;} +.media>.pull-right{margin-left:10px;} +.media-list{margin-left:0;list-style:none;} +.tooltip{position:absolute;z-index:1030;display:block;visibility:visible;font-size:11px;line-height:1.4;opacity:0;filter:alpha(opacity=0);}.tooltip.in{opacity:0.8;filter:alpha(opacity=80);} +.tooltip.top{margin-top:-3px;padding:5px 0;} +.tooltip.right{margin-left:3px;padding:0 5px;} +.tooltip.bottom{margin-top:3px;padding:5px 0;} +.tooltip.left{margin-left:-3px;padding:0 5px;} +.tooltip-inner{max-width:200px;padding:8px;color:#ffffff;text-align:center;text-decoration:none;background-color:#000000;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.tooltip-arrow{position:absolute;width:0;height:0;border-color:transparent;border-style:solid;} +.tooltip.top .tooltip-arrow{bottom:0;left:50%;margin-left:-5px;border-width:5px 5px 0;border-top-color:#000000;} +.tooltip.right .tooltip-arrow{top:50%;left:0;margin-top:-5px;border-width:5px 5px 5px 0;border-right-color:#000000;} +.tooltip.left .tooltip-arrow{top:50%;right:0;margin-top:-5px;border-width:5px 0 5px 5px;border-left-color:#000000;} +.tooltip.bottom .tooltip-arrow{top:0;left:50%;margin-left:-5px;border-width:0 5px 5px;border-bottom-color:#000000;} +.popover{position:absolute;top:0;left:0;z-index:1010;display:none;max-width:276px;padding:1px;text-align:left;background-color:#ffffff;-webkit-background-clip:padding-box;-moz-background-clip:padding;background-clip:padding-box;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-moz-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);white-space:normal;}.popover.top{margin-top:-10px;} +.popover.right{margin-left:10px;} +.popover.bottom{margin-top:10px;} +.popover.left{margin-left:-10px;} +.popover-title{margin:0;padding:8px 14px;font-size:14px;font-weight:normal;line-height:18px;background-color:#f7f7f7;border-bottom:1px solid #ebebeb;-webkit-border-radius:5px 5px 0 0;-moz-border-radius:5px 5px 0 0;border-radius:5px 5px 0 0;}.popover-title:empty{display:none;} +.popover-content{padding:9px 14px;} +.popover .arrow,.popover .arrow:after{position:absolute;display:block;width:0;height:0;border-color:transparent;border-style:solid;} +.popover .arrow{border-width:11px;} +.popover .arrow:after{border-width:10px;content:"";} +.popover.top .arrow{left:50%;margin-left:-11px;border-bottom-width:0;border-top-color:#999;border-top-color:rgba(0, 0, 0, 0.25);bottom:-11px;}.popover.top .arrow:after{bottom:1px;margin-left:-10px;border-bottom-width:0;border-top-color:#ffffff;} +.popover.right .arrow{top:50%;left:-11px;margin-top:-11px;border-left-width:0;border-right-color:#999;border-right-color:rgba(0, 0, 0, 0.25);}.popover.right .arrow:after{left:1px;bottom:-10px;border-left-width:0;border-right-color:#ffffff;} +.popover.bottom .arrow{left:50%;margin-left:-11px;border-top-width:0;border-bottom-color:#999;border-bottom-color:rgba(0, 0, 0, 0.25);top:-11px;}.popover.bottom .arrow:after{top:1px;margin-left:-10px;border-top-width:0;border-bottom-color:#ffffff;} +.popover.left .arrow{top:50%;right:-11px;margin-top:-11px;border-right-width:0;border-left-color:#999;border-left-color:rgba(0, 0, 0, 0.25);}.popover.left .arrow:after{right:1px;border-right-width:0;border-left-color:#ffffff;bottom:-10px;} +.modal-backdrop{position:fixed;top:0;right:0;bottom:0;left:0;z-index:1040;background-color:#000000;}.modal-backdrop.fade{opacity:0;} +.modal-backdrop,.modal-backdrop.fade.in{opacity:0.8;filter:alpha(opacity=80);} +.modal{position:fixed;top:10%;left:50%;z-index:1050;width:560px;margin-left:-280px;background-color:#ffffff;border:1px solid #999;border:1px solid rgba(0, 0, 0, 0.3);*border:1px solid #999;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);-moz-box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);-webkit-background-clip:padding-box;-moz-background-clip:padding-box;background-clip:padding-box;outline:none;}.modal.fade{-webkit-transition:opacity .3s linear, top .3s ease-out;-moz-transition:opacity .3s linear, top .3s ease-out;-o-transition:opacity .3s linear, top .3s ease-out;transition:opacity .3s linear, top .3s ease-out;top:-25%;} +.modal.fade.in{top:10%;} +.modal-header{padding:9px 15px;border-bottom:1px solid #eee;}.modal-header .close{margin-top:2px;} +.modal-header h3{margin:0;line-height:30px;} +.modal-body{position:relative;overflow-y:auto;max-height:400px;padding:15px;} +.modal-form{margin-bottom:0;} +.modal-footer{padding:14px 15px 15px;margin-bottom:0;text-align:right;background-color:#f5f5f5;border-top:1px solid #ddd;-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;-webkit-box-shadow:inset 0 1px 0 #ffffff;-moz-box-shadow:inset 0 1px 0 #ffffff;box-shadow:inset 0 1px 0 #ffffff;*zoom:1;}.modal-footer:before,.modal-footer:after{display:table;content:"";line-height:0;} +.modal-footer:after{clear:both;} +.modal-footer .btn+.btn{margin-left:5px;margin-bottom:0;} +.modal-footer .btn-group .btn+.btn{margin-left:-1px;} +.modal-footer .btn-block+.btn-block{margin-left:0;} +.dropup,.dropdown{position:relative;} +.dropdown-toggle{*margin-bottom:-3px;} +.dropdown-toggle:active,.open .dropdown-toggle{outline:0;} +.caret{display:inline-block;width:0;height:0;vertical-align:top;border-top:4px solid #000000;border-right:4px solid transparent;border-left:4px solid transparent;content:"";} +.dropdown .caret{margin-top:8px;margin-left:2px;} +.dropdown-menu{position:absolute;top:100%;left:0;z-index:1000;display:none;float:left;min-width:160px;padding:5px 0;margin:2px 0 0;list-style:none;background-color:#ffffff;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);*border-right-width:2px;*border-bottom-width:2px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-moz-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-webkit-background-clip:padding-box;-moz-background-clip:padding;background-clip:padding-box;}.dropdown-menu.pull-right{right:0;left:auto;} +.dropdown-menu .divider{*width:100%;height:1px;margin:9px 1px;*margin:-5px 0 5px;overflow:hidden;background-color:#e5e5e5;border-bottom:1px solid #ffffff;} +.dropdown-menu>li>a{display:block;padding:3px 20px;clear:both;font-weight:normal;line-height:20px;color:#333333;white-space:nowrap;} +.dropdown-menu>li>a:hover,.dropdown-menu>li>a:focus,.dropdown-submenu:hover>a,.dropdown-submenu:focus>a{text-decoration:none;color:#ffffff;background-color:#0081c2;background-image:-moz-linear-gradient(top, #0088cc, #0077b3);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0077b3));background-image:-webkit-linear-gradient(top, #0088cc, #0077b3);background-image:-o-linear-gradient(top, #0088cc, #0077b3);background-image:linear-gradient(to bottom, #0088cc, #0077b3);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0077b3', GradientType=0);} +.dropdown-menu>.active>a,.dropdown-menu>.active>a:hover,.dropdown-menu>.active>a:focus{color:#ffffff;text-decoration:none;outline:0;background-color:#0081c2;background-image:-moz-linear-gradient(top, #0088cc, #0077b3);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0077b3));background-image:-webkit-linear-gradient(top, #0088cc, #0077b3);background-image:-o-linear-gradient(top, #0088cc, #0077b3);background-image:linear-gradient(to bottom, #0088cc, #0077b3);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0077b3', GradientType=0);} +.dropdown-menu>.disabled>a,.dropdown-menu>.disabled>a:hover,.dropdown-menu>.disabled>a:focus{color:#999999;} +.dropdown-menu>.disabled>a:hover,.dropdown-menu>.disabled>a:focus{text-decoration:none;background-color:transparent;background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);cursor:default;} +.open{*z-index:1000;}.open>.dropdown-menu{display:block;} +.dropdown-backdrop{position:fixed;left:0;right:0;bottom:0;top:0;z-index:990;} +.pull-right>.dropdown-menu{right:0;left:auto;} +.dropup .caret,.navbar-fixed-bottom .dropdown .caret{border-top:0;border-bottom:4px solid #000000;content:"";} +.dropup .dropdown-menu,.navbar-fixed-bottom .dropdown .dropdown-menu{top:auto;bottom:100%;margin-bottom:1px;} +.dropdown-submenu{position:relative;} +.dropdown-submenu>.dropdown-menu{top:0;left:100%;margin-top:-6px;margin-left:-1px;-webkit-border-radius:0 6px 6px 6px;-moz-border-radius:0 6px 6px 6px;border-radius:0 6px 6px 6px;} +.dropdown-submenu:hover>.dropdown-menu{display:block;} +.dropup .dropdown-submenu>.dropdown-menu{top:auto;bottom:0;margin-top:0;margin-bottom:-2px;-webkit-border-radius:5px 5px 5px 0;-moz-border-radius:5px 5px 5px 0;border-radius:5px 5px 5px 0;} +.dropdown-submenu>a:after{display:block;content:" ";float:right;width:0;height:0;border-color:transparent;border-style:solid;border-width:5px 0 5px 5px;border-left-color:#cccccc;margin-top:5px;margin-right:-10px;} +.dropdown-submenu:hover>a:after{border-left-color:#ffffff;} +.dropdown-submenu.pull-left{float:none;}.dropdown-submenu.pull-left>.dropdown-menu{left:-100%;margin-left:10px;-webkit-border-radius:6px 0 6px 6px;-moz-border-radius:6px 0 6px 6px;border-radius:6px 0 6px 6px;} +.dropdown .dropdown-menu .nav-header{padding-left:20px;padding-right:20px;} +.typeahead{z-index:1051;margin-top:2px;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.accordion{margin-bottom:20px;} +.accordion-group{margin-bottom:2px;border:1px solid #e5e5e5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.accordion-heading{border-bottom:0;} +.accordion-heading .accordion-toggle{display:block;padding:8px 15px;} +.accordion-toggle{cursor:pointer;} +.accordion-inner{padding:9px 15px;border-top:1px solid #e5e5e5;} +.carousel{position:relative;margin-bottom:20px;line-height:1;} +.carousel-inner{overflow:hidden;width:100%;position:relative;} +.carousel-inner>.item{display:none;position:relative;-webkit-transition:0.6s ease-in-out left;-moz-transition:0.6s ease-in-out left;-o-transition:0.6s ease-in-out left;transition:0.6s ease-in-out left;}.carousel-inner>.item>img,.carousel-inner>.item>a>img{display:block;line-height:1;} +.carousel-inner>.active,.carousel-inner>.next,.carousel-inner>.prev{display:block;} +.carousel-inner>.active{left:0;} +.carousel-inner>.next,.carousel-inner>.prev{position:absolute;top:0;width:100%;} +.carousel-inner>.next{left:100%;} +.carousel-inner>.prev{left:-100%;} +.carousel-inner>.next.left,.carousel-inner>.prev.right{left:0;} +.carousel-inner>.active.left{left:-100%;} +.carousel-inner>.active.right{left:100%;} +.carousel-control{position:absolute;top:40%;left:15px;width:40px;height:40px;margin-top:-20px;font-size:60px;font-weight:100;line-height:30px;color:#ffffff;text-align:center;background:#222222;border:3px solid #ffffff;-webkit-border-radius:23px;-moz-border-radius:23px;border-radius:23px;opacity:0.5;filter:alpha(opacity=50);}.carousel-control.right{left:auto;right:15px;} +.carousel-control:hover,.carousel-control:focus{color:#ffffff;text-decoration:none;opacity:0.9;filter:alpha(opacity=90);} +.carousel-indicators{position:absolute;top:15px;right:15px;z-index:5;margin:0;list-style:none;}.carousel-indicators li{display:block;float:left;width:10px;height:10px;margin-left:5px;text-indent:-999px;background-color:#ccc;background-color:rgba(255, 255, 255, 0.25);border-radius:5px;} +.carousel-indicators .active{background-color:#fff;} +.carousel-caption{position:absolute;left:0;right:0;bottom:0;padding:15px;background:#333333;background:rgba(0, 0, 0, 0.75);} +.carousel-caption h4,.carousel-caption p{color:#ffffff;line-height:20px;} +.carousel-caption h4{margin:0 0 5px;} +.carousel-caption p{margin-bottom:0;} +.well{min-height:20px;padding:19px;margin-bottom:20px;background-color:#f5f5f5;border:1px solid #e3e3e3;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);}.well blockquote{border-color:#ddd;border-color:rgba(0, 0, 0, 0.15);} +.well-large{padding:24px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.well-small{padding:9px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.close{float:right;font-size:20px;font-weight:bold;line-height:20px;color:#000000;text-shadow:0 1px 0 #ffffff;opacity:0.2;filter:alpha(opacity=20);}.close:hover,.close:focus{color:#000000;text-decoration:none;cursor:pointer;opacity:0.4;filter:alpha(opacity=40);} +button.close{padding:0;cursor:pointer;background:transparent;border:0;-webkit-appearance:none;} +.pull-right{float:right;} +.pull-left{float:left;} +.hide{display:none;} +.show{display:block;} +.invisible{visibility:hidden;} +.affix{position:fixed;} +.fade{opacity:0;-webkit-transition:opacity 0.15s linear;-moz-transition:opacity 0.15s linear;-o-transition:opacity 0.15s linear;transition:opacity 0.15s linear;}.fade.in{opacity:1;} +.collapse{position:relative;height:0;overflow:hidden;-webkit-transition:height 0.35s ease;-moz-transition:height 0.35s ease;-o-transition:height 0.35s ease;transition:height 0.35s ease;}.collapse.in{height:auto;} +@-ms-viewport{width:device-width;}.hidden{display:none;visibility:hidden;} +.visible-phone{display:none !important;} +.visible-tablet{display:none !important;} +.hidden-desktop{display:none !important;} +.visible-desktop{display:inherit !important;} +@media (min-width:768px) and (max-width:979px){.hidden-desktop{display:inherit !important;} .visible-desktop{display:none !important ;} .visible-tablet{display:inherit !important;} .hidden-tablet{display:none !important;}}@media (max-width:767px){.hidden-desktop{display:inherit !important;} .visible-desktop{display:none !important;} .visible-phone{display:inherit !important;} .hidden-phone{display:none !important;}}.visible-print{display:none !important;} +@media print{.visible-print{display:inherit !important;} .hidden-print{display:none !important;}}@media (max-width:767px){body{padding-left:20px;padding-right:20px;} .navbar-fixed-top,.navbar-fixed-bottom,.navbar-static-top{margin-left:-20px;margin-right:-20px;} .container-fluid{padding:0;} .dl-horizontal dt{float:none;clear:none;width:auto;text-align:left;} .dl-horizontal dd{margin-left:0;} .container{width:auto;} .row-fluid{width:100%;} .row,.thumbnails{margin-left:0;} .thumbnails>li{float:none;margin-left:0;} [class*="span"],.uneditable-input[class*="span"],.row-fluid [class*="span"]{float:none;display:block;width:100%;margin-left:0;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .span12,.row-fluid .span12{width:100%;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .row-fluid [class*="offset"]:first-child{margin-left:0;} .input-large,.input-xlarge,.input-xxlarge,input[class*="span"],select[class*="span"],textarea[class*="span"],.uneditable-input{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .input-prepend input,.input-append input,.input-prepend input[class*="span"],.input-append input[class*="span"]{display:inline-block;width:auto;} .controls-row [class*="span"]+[class*="span"]{margin-left:0;} .modal{position:fixed;top:20px;left:20px;right:20px;width:auto;margin:0;}.modal.fade{top:-100px;} .modal.fade.in{top:20px;}}@media (max-width:480px){.nav-collapse{-webkit-transform:translate3d(0, 0, 0);} .page-header h1 small{display:block;line-height:20px;} input[type="checkbox"],input[type="radio"]{border:1px solid #ccc;} .form-horizontal .control-label{float:none;width:auto;padding-top:0;text-align:left;} .form-horizontal .controls{margin-left:0;} .form-horizontal .control-list{padding-top:0;} .form-horizontal .form-actions{padding-left:10px;padding-right:10px;} .media .pull-left,.media .pull-right{float:none;display:block;margin-bottom:10px;} .media-object{margin-right:0;margin-left:0;} .modal{top:10px;left:10px;right:10px;} .modal-header .close{padding:10px;margin:-10px;} .carousel-caption{position:static;}}@media (min-width:768px) and (max-width:979px){.row{margin-left:-20px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} .row:after{clear:both;} [class*="span"]{float:left;min-height:1px;margin-left:20px;} .container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:724px;} .span12{width:724px;} .span11{width:662px;} .span10{width:600px;} .span9{width:538px;} .span8{width:476px;} .span7{width:414px;} .span6{width:352px;} .span5{width:290px;} .span4{width:228px;} .span3{width:166px;} .span2{width:104px;} .span1{width:42px;} .offset12{margin-left:764px;} .offset11{margin-left:702px;} .offset10{margin-left:640px;} .offset9{margin-left:578px;} .offset8{margin-left:516px;} .offset7{margin-left:454px;} .offset6{margin-left:392px;} .offset5{margin-left:330px;} .offset4{margin-left:268px;} .offset3{margin-left:206px;} .offset2{margin-left:144px;} .offset1{margin-left:82px;} .row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} .row-fluid:after{clear:both;} .row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.7624309392265194%;*margin-left:2.709239449864817%;} .row-fluid [class*="span"]:first-child{margin-left:0;} .row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.7624309392265194%;} .row-fluid .span12{width:100%;*width:99.94680851063829%;} .row-fluid .span11{width:91.43646408839778%;*width:91.38327259903608%;} .row-fluid .span10{width:82.87292817679558%;*width:82.81973668743387%;} .row-fluid .span9{width:74.30939226519337%;*width:74.25620077583166%;} .row-fluid .span8{width:65.74585635359117%;*width:65.69266486422946%;} .row-fluid .span7{width:57.18232044198895%;*width:57.12912895262725%;} .row-fluid .span6{width:48.61878453038674%;*width:48.56559304102504%;} .row-fluid .span5{width:40.05524861878453%;*width:40.00205712942283%;} .row-fluid .span4{width:31.491712707182323%;*width:31.43852121782062%;} .row-fluid .span3{width:22.92817679558011%;*width:22.87498530621841%;} .row-fluid .span2{width:14.3646408839779%;*width:14.311449394616199%;} .row-fluid .span1{width:5.801104972375691%;*width:5.747913483013988%;} .row-fluid .offset12{margin-left:105.52486187845304%;*margin-left:105.41847889972962%;} .row-fluid .offset12:first-child{margin-left:102.76243093922652%;*margin-left:102.6560479605031%;} .row-fluid .offset11{margin-left:96.96132596685082%;*margin-left:96.8549429881274%;} .row-fluid .offset11:first-child{margin-left:94.1988950276243%;*margin-left:94.09251204890089%;} .row-fluid .offset10{margin-left:88.39779005524862%;*margin-left:88.2914070765252%;} .row-fluid .offset10:first-child{margin-left:85.6353591160221%;*margin-left:85.52897613729868%;} .row-fluid .offset9{margin-left:79.8342541436464%;*margin-left:79.72787116492299%;} .row-fluid .offset9:first-child{margin-left:77.07182320441989%;*margin-left:76.96544022569647%;} .row-fluid .offset8{margin-left:71.2707182320442%;*margin-left:71.16433525332079%;} .row-fluid .offset8:first-child{margin-left:68.50828729281768%;*margin-left:68.40190431409427%;} .row-fluid .offset7{margin-left:62.70718232044199%;*margin-left:62.600799341718584%;} .row-fluid .offset7:first-child{margin-left:59.94475138121547%;*margin-left:59.838368402492065%;} .row-fluid .offset6{margin-left:54.14364640883978%;*margin-left:54.037263430116376%;} .row-fluid .offset6:first-child{margin-left:51.38121546961326%;*margin-left:51.27483249088986%;} .row-fluid .offset5{margin-left:45.58011049723757%;*margin-left:45.47372751851417%;} .row-fluid .offset5:first-child{margin-left:42.81767955801105%;*margin-left:42.71129657928765%;} .row-fluid .offset4{margin-left:37.01657458563536%;*margin-left:36.91019160691196%;} .row-fluid .offset4:first-child{margin-left:34.25414364640884%;*margin-left:34.14776066768544%;} .row-fluid .offset3{margin-left:28.45303867403315%;*margin-left:28.346655695309746%;} .row-fluid .offset3:first-child{margin-left:25.69060773480663%;*margin-left:25.584224756083227%;} .row-fluid .offset2{margin-left:19.88950276243094%;*margin-left:19.783119783707537%;} .row-fluid .offset2:first-child{margin-left:17.12707182320442%;*margin-left:17.02068884448102%;} .row-fluid .offset1{margin-left:11.32596685082873%;*margin-left:11.219583872105325%;} .row-fluid .offset1:first-child{margin-left:8.56353591160221%;*margin-left:8.457152932878806%;} input,textarea,.uneditable-input{margin-left:0;} .controls-row [class*="span"]+[class*="span"]{margin-left:20px;} input.span12,textarea.span12,.uneditable-input.span12{width:710px;} input.span11,textarea.span11,.uneditable-input.span11{width:648px;} input.span10,textarea.span10,.uneditable-input.span10{width:586px;} input.span9,textarea.span9,.uneditable-input.span9{width:524px;} input.span8,textarea.span8,.uneditable-input.span8{width:462px;} input.span7,textarea.span7,.uneditable-input.span7{width:400px;} input.span6,textarea.span6,.uneditable-input.span6{width:338px;} input.span5,textarea.span5,.uneditable-input.span5{width:276px;} input.span4,textarea.span4,.uneditable-input.span4{width:214px;} input.span3,textarea.span3,.uneditable-input.span3{width:152px;} input.span2,textarea.span2,.uneditable-input.span2{width:90px;} input.span1,textarea.span1,.uneditable-input.span1{width:28px;}}@media (min-width:1200px){.row{margin-left:-30px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} .row:after{clear:both;} [class*="span"]{float:left;min-height:1px;margin-left:30px;} .container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:1170px;} .span12{width:1170px;} .span11{width:1070px;} .span10{width:970px;} .span9{width:870px;} .span8{width:770px;} .span7{width:670px;} .span6{width:570px;} .span5{width:470px;} .span4{width:370px;} .span3{width:270px;} .span2{width:170px;} .span1{width:70px;} .offset12{margin-left:1230px;} .offset11{margin-left:1130px;} .offset10{margin-left:1030px;} .offset9{margin-left:930px;} .offset8{margin-left:830px;} .offset7{margin-left:730px;} .offset6{margin-left:630px;} .offset5{margin-left:530px;} .offset4{margin-left:430px;} .offset3{margin-left:330px;} .offset2{margin-left:230px;} .offset1{margin-left:130px;} .row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} .row-fluid:after{clear:both;} .row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.564102564102564%;*margin-left:2.5109110747408616%;} .row-fluid [class*="span"]:first-child{margin-left:0;} .row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.564102564102564%;} .row-fluid .span12{width:100%;*width:99.94680851063829%;} .row-fluid .span11{width:91.45299145299145%;*width:91.39979996362975%;} .row-fluid .span10{width:82.90598290598291%;*width:82.8527914166212%;} .row-fluid .span9{width:74.35897435897436%;*width:74.30578286961266%;} .row-fluid .span8{width:65.81196581196582%;*width:65.75877432260411%;} .row-fluid .span7{width:57.26495726495726%;*width:57.21176577559556%;} .row-fluid .span6{width:48.717948717948715%;*width:48.664757228587014%;} .row-fluid .span5{width:40.17094017094017%;*width:40.11774868157847%;} .row-fluid .span4{width:31.623931623931625%;*width:31.570740134569924%;} .row-fluid .span3{width:23.076923076923077%;*width:23.023731587561375%;} .row-fluid .span2{width:14.52991452991453%;*width:14.476723040552828%;} .row-fluid .span1{width:5.982905982905983%;*width:5.929714493544281%;} .row-fluid .offset12{margin-left:105.12820512820512%;*margin-left:105.02182214948171%;} .row-fluid .offset12:first-child{margin-left:102.56410256410257%;*margin-left:102.45771958537915%;} .row-fluid .offset11{margin-left:96.58119658119658%;*margin-left:96.47481360247316%;} .row-fluid .offset11:first-child{margin-left:94.01709401709402%;*margin-left:93.91071103837061%;} .row-fluid .offset10{margin-left:88.03418803418803%;*margin-left:87.92780505546462%;} .row-fluid .offset10:first-child{margin-left:85.47008547008548%;*margin-left:85.36370249136206%;} .row-fluid .offset9{margin-left:79.48717948717949%;*margin-left:79.38079650845607%;} .row-fluid .offset9:first-child{margin-left:76.92307692307693%;*margin-left:76.81669394435352%;} .row-fluid .offset8{margin-left:70.94017094017094%;*margin-left:70.83378796144753%;} .row-fluid .offset8:first-child{margin-left:68.37606837606839%;*margin-left:68.26968539734497%;} .row-fluid .offset7{margin-left:62.393162393162385%;*margin-left:62.28677941443899%;} .row-fluid .offset7:first-child{margin-left:59.82905982905982%;*margin-left:59.72267685033642%;} .row-fluid .offset6{margin-left:53.84615384615384%;*margin-left:53.739770867430444%;} .row-fluid .offset6:first-child{margin-left:51.28205128205128%;*margin-left:51.175668303327875%;} .row-fluid .offset5{margin-left:45.299145299145295%;*margin-left:45.1927623204219%;} .row-fluid .offset5:first-child{margin-left:42.73504273504273%;*margin-left:42.62865975631933%;} .row-fluid .offset4{margin-left:36.75213675213675%;*margin-left:36.645753773413354%;} .row-fluid .offset4:first-child{margin-left:34.18803418803419%;*margin-left:34.081651209310785%;} .row-fluid .offset3{margin-left:28.205128205128204%;*margin-left:28.0987452264048%;} .row-fluid .offset3:first-child{margin-left:25.641025641025642%;*margin-left:25.53464266230224%;} .row-fluid .offset2{margin-left:19.65811965811966%;*margin-left:19.551736679396257%;} .row-fluid .offset2:first-child{margin-left:17.094017094017094%;*margin-left:16.98763411529369%;} .row-fluid .offset1{margin-left:11.11111111111111%;*margin-left:11.004728132387708%;} .row-fluid .offset1:first-child{margin-left:8.547008547008547%;*margin-left:8.440625568285142%;} input,textarea,.uneditable-input{margin-left:0;} .controls-row [class*="span"]+[class*="span"]{margin-left:30px;} input.span12,textarea.span12,.uneditable-input.span12{width:1156px;} input.span11,textarea.span11,.uneditable-input.span11{width:1056px;} input.span10,textarea.span10,.uneditable-input.span10{width:956px;} input.span9,textarea.span9,.uneditable-input.span9{width:856px;} input.span8,textarea.span8,.uneditable-input.span8{width:756px;} input.span7,textarea.span7,.uneditable-input.span7{width:656px;} input.span6,textarea.span6,.uneditable-input.span6{width:556px;} input.span5,textarea.span5,.uneditable-input.span5{width:456px;} input.span4,textarea.span4,.uneditable-input.span4{width:356px;} input.span3,textarea.span3,.uneditable-input.span3{width:256px;} input.span2,textarea.span2,.uneditable-input.span2{width:156px;} input.span1,textarea.span1,.uneditable-input.span1{width:56px;} .thumbnails{margin-left:-30px;} .thumbnails>li{margin-left:30px;} .row-fluid .thumbnails{margin-left:0;}}@media (max-width:979px){body{padding-top:0;} .navbar-fixed-top,.navbar-fixed-bottom{position:static;} .navbar-fixed-top{margin-bottom:20px;} .navbar-fixed-bottom{margin-top:20px;} .navbar-fixed-top .navbar-inner,.navbar-fixed-bottom .navbar-inner{padding:5px;} .navbar .container{width:auto;padding:0;} .navbar .brand{padding-left:10px;padding-right:10px;margin:0 0 0 -5px;} .nav-collapse{clear:both;} .nav-collapse .nav{float:none;margin:0 0 10px;} .nav-collapse .nav>li{float:none;} .nav-collapse .nav>li>a{margin-bottom:2px;} .nav-collapse .nav>.divider-vertical{display:none;} .nav-collapse .nav .nav-header{color:#777777;text-shadow:none;} .nav-collapse .nav>li>a,.nav-collapse .dropdown-menu a{padding:9px 15px;font-weight:bold;color:#777777;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} .nav-collapse .btn{padding:4px 10px 4px;font-weight:normal;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} .nav-collapse .dropdown-menu li+li a{margin-bottom:2px;} .nav-collapse .nav>li>a:hover,.nav-collapse .nav>li>a:focus,.nav-collapse .dropdown-menu a:hover,.nav-collapse .dropdown-menu a:focus{background-color:#f2f2f2;} .navbar-inverse .nav-collapse .nav>li>a,.navbar-inverse .nav-collapse .dropdown-menu a{color:#999999;} .navbar-inverse .nav-collapse .nav>li>a:hover,.navbar-inverse .nav-collapse .nav>li>a:focus,.navbar-inverse .nav-collapse .dropdown-menu a:hover,.navbar-inverse .nav-collapse .dropdown-menu a:focus{background-color:#111111;} .nav-collapse.in .btn-group{margin-top:5px;padding:0;} .nav-collapse .dropdown-menu{position:static;top:auto;left:auto;float:none;display:none;max-width:none;margin:0 15px;padding:0;background-color:transparent;border:none;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} .nav-collapse .open>.dropdown-menu{display:block;} .nav-collapse .dropdown-menu:before,.nav-collapse .dropdown-menu:after{display:none;} .nav-collapse .dropdown-menu .divider{display:none;} .nav-collapse .nav>li>.dropdown-menu:before,.nav-collapse .nav>li>.dropdown-menu:after{display:none;} .nav-collapse .navbar-form,.nav-collapse .navbar-search{float:none;padding:10px 15px;margin:10px 0;border-top:1px solid #f2f2f2;border-bottom:1px solid #f2f2f2;-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);} .navbar-inverse .nav-collapse .navbar-form,.navbar-inverse .nav-collapse .navbar-search{border-top-color:#111111;border-bottom-color:#111111;} .navbar .nav-collapse .nav.pull-right{float:none;margin-left:0;} .nav-collapse,.nav-collapse.collapse{overflow:hidden;height:0;} .navbar .btn-navbar{display:block;} .navbar-static .navbar-inner{padding-left:10px;padding-right:10px;}}@media (min-width:980px){.nav-collapse.collapse{height:auto !important;overflow:visible !important;}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js new file mode 100644 index 0000000000..7abb9011cc --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -0,0 +1,495 @@ +/* + SortTable + version 2 + 7th April 2007 + Stuart Langridge, http://www.kryogenix.org/code/browser/sorttable/ + + Instructions: + Download this file + Add to your HTML + Add class="sortable" to any table you'd like to make sortable + Click on the headers to sort + + Thanks to many, many people for contributions and suggestions. + Licenced as X11: http://www.kryogenix.org/code/browser/licence.html + This basically means: do what you want with it. +*/ + + +var stIsIE = /*@cc_on!@*/false; + +sorttable = { + init: function() { + // quit if this function has already been called + if (arguments.callee.done) return; + // flag this function so we don't do the same thing twice + arguments.callee.done = true; + // kill the timer + if (_timer) clearInterval(_timer); + + if (!document.createElement || !document.getElementsByTagName) return; + + sorttable.DATE_RE = /^(\d\d?)[\/\.-](\d\d?)[\/\.-]((\d\d)?\d\d)$/; + + forEach(document.getElementsByTagName('table'), function(table) { + if (table.className.search(/\bsortable\b/) != -1) { + sorttable.makeSortable(table); + } + }); + + }, + + makeSortable: function(table) { + if (table.getElementsByTagName('thead').length == 0) { + // table doesn't have a tHead. Since it should have, create one and + // put the first table row in it. + the = document.createElement('thead'); + the.appendChild(table.rows[0]); + table.insertBefore(the,table.firstChild); + } + // Safari doesn't support table.tHead, sigh + if (table.tHead == null) table.tHead = table.getElementsByTagName('thead')[0]; + + if (table.tHead.rows.length != 1) return; // can't cope with two header rows + + // Sorttable v1 put rows with a class of "sortbottom" at the bottom (as + // "total" rows, for example). This is B&R, since what you're supposed + // to do is put them in a tfoot. So, if there are sortbottom rows, + // for backwards compatibility, move them to tfoot (creating it if needed). + sortbottomrows = []; + for (var i=0; i5' : ' ▴'; + this.appendChild(sortrevind); + return; + } + if (this.className.search(/\bsorttable_sorted_reverse\b/) != -1) { + // if we're already sorted by this column in reverse, just + // re-reverse the table, which is quicker + sorttable.reverse(this.sorttable_tbody); + this.className = this.className.replace('sorttable_sorted_reverse', + 'sorttable_sorted'); + this.removeChild(document.getElementById('sorttable_sortrevind')); + sortfwdind = document.createElement('span'); + sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + this.appendChild(sortfwdind); + return; + } + + // remove sorttable_sorted classes + theadrow = this.parentNode; + forEach(theadrow.childNodes, function(cell) { + if (cell.nodeType == 1) { // an element + cell.className = cell.className.replace('sorttable_sorted_reverse',''); + cell.className = cell.className.replace('sorttable_sorted',''); + } + }); + sortfwdind = document.getElementById('sorttable_sortfwdind'); + if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } + sortrevind = document.getElementById('sorttable_sortrevind'); + if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } + + this.className += ' sorttable_sorted'; + sortfwdind = document.createElement('span'); + sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + this.appendChild(sortfwdind); + + // build an array to sort. This is a Schwartzian transform thing, + // i.e., we "decorate" each row with the actual sort key, + // sort based on the sort keys, and then put the rows back in order + // which is a lot faster because you only do getInnerText once per row + row_array = []; + col = this.sorttable_columnindex; + rows = this.sorttable_tbody.rows; + for (var j=0; j 12) { + // definitely dd/mm + return sorttable.sort_ddmm; + } else if (second > 12) { + return sorttable.sort_mmdd; + } else { + // looks like a date, but we can't tell which, so assume + // that it's dd/mm (English imperialism!) and keep looking + sortfn = sorttable.sort_ddmm; + } + } + } + } + return sortfn; + }, + + getInnerText: function(node) { + // gets the text we want to use for sorting for a cell. + // strips leading and trailing whitespace. + // this is *not* a generic getInnerText function; it's special to sorttable. + // for example, you can override the cell text with a customkey attribute. + // it also gets .value for fields. + + if (!node) return ""; + + hasInputs = (typeof node.getElementsByTagName == 'function') && + node.getElementsByTagName('input').length; + + if (node.getAttribute("sorttable_customkey") != null) { + return node.getAttribute("sorttable_customkey"); + } + else if (typeof node.textContent != 'undefined' && !hasInputs) { + return node.textContent.replace(/^\s+|\s+$/g, ''); + } + else if (typeof node.innerText != 'undefined' && !hasInputs) { + return node.innerText.replace(/^\s+|\s+$/g, ''); + } + else if (typeof node.text != 'undefined' && !hasInputs) { + return node.text.replace(/^\s+|\s+$/g, ''); + } + else { + switch (node.nodeType) { + case 3: + if (node.nodeName.toLowerCase() == 'input') { + return node.value.replace(/^\s+|\s+$/g, ''); + } + case 4: + return node.nodeValue.replace(/^\s+|\s+$/g, ''); + break; + case 1: + case 11: + var innerText = ''; + for (var i = 0; i < node.childNodes.length; i++) { + innerText += sorttable.getInnerText(node.childNodes[i]); + } + return innerText.replace(/^\s+|\s+$/g, ''); + break; + default: + return ''; + } + } + }, + + reverse: function(tbody) { + // reverse the rows in a tbody + newrows = []; + for (var i=0; i=0; i--) { + tbody.appendChild(newrows[i]); + } + delete newrows; + }, + + /* sort functions + each sort function takes two parameters, a and b + you are comparing a[0] and b[0] */ + sort_numeric: function(a,b) { + aa = parseFloat(a[0].replace(/[^0-9.-]/g,'')); + if (isNaN(aa)) aa = 0; + bb = parseFloat(b[0].replace(/[^0-9.-]/g,'')); + if (isNaN(bb)) bb = 0; + return aa-bb; + }, + sort_alpha: function(a,b) { + if (a[0]==b[0]) return 0; + if (a[0] 0 ) { + var q = list[i]; list[i] = list[i+1]; list[i+1] = q; + swap = true; + } + } // for + t--; + + if (!swap) break; + + for(var i = t; i > b; --i) { + if ( comp_func(list[i], list[i-1]) < 0 ) { + var q = list[i]; list[i] = list[i-1]; list[i-1] = q; + swap = true; + } + } // for + b++; + + } // while(swap) + } +} + +/* ****************************************************************** + Supporting functions: bundled here to avoid depending on a library + ****************************************************************** */ + +// Dean Edwards/Matthias Miller/John Resig + +/* for Mozilla/Opera9 */ +if (document.addEventListener) { + document.addEventListener("DOMContentLoaded", sorttable.init, false); +} + +/* for Internet Explorer */ +/*@cc_on @*/ +/*@if (@_win32) + document.write(" to your HTML - Add class="sortable" to any table you'd like to make sortable - Click on the headers to sort - - Thanks to many, many people for contributions and suggestions. - Licenced as X11: http://www.kryogenix.org/code/browser/licence.html - This basically means: do what you want with it. -*/ - - -var stIsIE = /*@cc_on!@*/false; - -sorttable = { - init: function() { - // quit if this function has already been called - if (arguments.callee.done) return; - // flag this function so we don't do the same thing twice - arguments.callee.done = true; - // kill the timer - if (_timer) clearInterval(_timer); - - if (!document.createElement || !document.getElementsByTagName) return; - - sorttable.DATE_RE = /^(\d\d?)[\/\.-](\d\d?)[\/\.-]((\d\d)?\d\d)$/; - - forEach(document.getElementsByTagName('table'), function(table) { - if (table.className.search(/\bsortable\b/) != -1) { - sorttable.makeSortable(table); - } - }); - - }, - - makeSortable: function(table) { - if (table.getElementsByTagName('thead').length == 0) { - // table doesn't have a tHead. Since it should have, create one and - // put the first table row in it. - the = document.createElement('thead'); - the.appendChild(table.rows[0]); - table.insertBefore(the,table.firstChild); - } - // Safari doesn't support table.tHead, sigh - if (table.tHead == null) table.tHead = table.getElementsByTagName('thead')[0]; - - if (table.tHead.rows.length != 1) return; // can't cope with two header rows - - // Sorttable v1 put rows with a class of "sortbottom" at the bottom (as - // "total" rows, for example). This is B&R, since what you're supposed - // to do is put them in a tfoot. So, if there are sortbottom rows, - // for backwards compatibility, move them to tfoot (creating it if needed). - sortbottomrows = []; - for (var i=0; i5' : ' ▴'; - this.appendChild(sortrevind); - return; - } - if (this.className.search(/\bsorttable_sorted_reverse\b/) != -1) { - // if we're already sorted by this column in reverse, just - // re-reverse the table, which is quicker - sorttable.reverse(this.sorttable_tbody); - this.className = this.className.replace('sorttable_sorted_reverse', - 'sorttable_sorted'); - this.removeChild(document.getElementById('sorttable_sortrevind')); - sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; - this.appendChild(sortfwdind); - return; - } - - // remove sorttable_sorted classes - theadrow = this.parentNode; - forEach(theadrow.childNodes, function(cell) { - if (cell.nodeType == 1) { // an element - cell.className = cell.className.replace('sorttable_sorted_reverse',''); - cell.className = cell.className.replace('sorttable_sorted',''); - } - }); - sortfwdind = document.getElementById('sorttable_sortfwdind'); - if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } - sortrevind = document.getElementById('sorttable_sortrevind'); - if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } - - this.className += ' sorttable_sorted'; - sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; - this.appendChild(sortfwdind); - - // build an array to sort. This is a Schwartzian transform thing, - // i.e., we "decorate" each row with the actual sort key, - // sort based on the sort keys, and then put the rows back in order - // which is a lot faster because you only do getInnerText once per row - row_array = []; - col = this.sorttable_columnindex; - rows = this.sorttable_tbody.rows; - for (var j=0; j 12) { - // definitely dd/mm - return sorttable.sort_ddmm; - } else if (second > 12) { - return sorttable.sort_mmdd; - } else { - // looks like a date, but we can't tell which, so assume - // that it's dd/mm (English imperialism!) and keep looking - sortfn = sorttable.sort_ddmm; - } - } - } - } - return sortfn; - }, - - getInnerText: function(node) { - // gets the text we want to use for sorting for a cell. - // strips leading and trailing whitespace. - // this is *not* a generic getInnerText function; it's special to sorttable. - // for example, you can override the cell text with a customkey attribute. - // it also gets .value for fields. - - if (!node) return ""; - - hasInputs = (typeof node.getElementsByTagName == 'function') && - node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { - return node.getAttribute("sorttable_customkey"); - } - else if (typeof node.textContent != 'undefined' && !hasInputs) { - return node.textContent.replace(/^\s+|\s+$/g, ''); - } - else if (typeof node.innerText != 'undefined' && !hasInputs) { - return node.innerText.replace(/^\s+|\s+$/g, ''); - } - else if (typeof node.text != 'undefined' && !hasInputs) { - return node.text.replace(/^\s+|\s+$/g, ''); - } - else { - switch (node.nodeType) { - case 3: - if (node.nodeName.toLowerCase() == 'input') { - return node.value.replace(/^\s+|\s+$/g, ''); - } - case 4: - return node.nodeValue.replace(/^\s+|\s+$/g, ''); - break; - case 1: - case 11: - var innerText = ''; - for (var i = 0; i < node.childNodes.length; i++) { - innerText += sorttable.getInnerText(node.childNodes[i]); - } - return innerText.replace(/^\s+|\s+$/g, ''); - break; - default: - return ''; - } - } - }, - - reverse: function(tbody) { - // reverse the rows in a tbody - newrows = []; - for (var i=0; i=0; i--) { - tbody.appendChild(newrows[i]); - } - delete newrows; - }, - - /* sort functions - each sort function takes two parameters, a and b - you are comparing a[0] and b[0] */ - sort_numeric: function(a,b) { - aa = parseFloat(a[0].replace(/[^0-9.-]/g,'')); - if (isNaN(aa)) aa = 0; - bb = parseFloat(b[0].replace(/[^0-9.-]/g,'')); - if (isNaN(bb)) bb = 0; - return aa-bb; - }, - sort_alpha: function(a,b) { - if (a[0]==b[0]) return 0; - if (a[0] 0 ) { - var q = list[i]; list[i] = list[i+1]; list[i+1] = q; - swap = true; - } - } // for - t--; - - if (!swap) break; - - for(var i = t; i > b; --i) { - if ( comp_func(list[i], list[i-1]) < 0 ) { - var q = list[i]; list[i] = list[i-1]; list[i-1] = q; - swap = true; - } - } // for - b++; - - } // while(swap) - } -} - -/* ****************************************************************** - Supporting functions: bundled here to avoid depending on a library - ****************************************************************** */ - -// Dean Edwards/Matthias Miller/John Resig - -/* for Mozilla/Opera9 */ -if (document.addEventListener) { - document.addEventListener("DOMContentLoaded", sorttable.init, false); -} - -/* for Internet Explorer */ -/*@cc_on @*/ -/*@if (@_win32) - document.write(" + {sc.appName} - {title} + + + + +
+
+
+

+ {title} +

+
+
+ {content} +
+ + + } + + /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */ + def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = { + + + + + + + {title} + + +
+
+
+

+ + {title} +

+
+
+ {content} +
+ + + } + + /** Returns an HTML table constructed by generating a row for each object in a sequence. */ + def listingTable[T]( + headers: Seq[String], + makeRow: T => Seq[Node], + rows: Seq[T], + fixedWidth: Boolean = false): Seq[Node] = { + + val colWidth = 100.toDouble / headers.size + val colWidthAttr = if (fixedWidth) colWidth + "%" else "" + var tableClass = "table table-bordered table-striped table-condensed sortable" + if (fixedWidth) { + tableClass += " table-fixed" + } + + + {headers.map(h => )} + + {rows.map(r => makeRow(r))} + +
{h}
+ } +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala new file mode 100644 index 0000000000..0ecb22d2f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.scheduler.cluster.SchedulingMode + + +/** + * Continuously generates jobs that expose various features of the WebUI (internal testing tool). + * + * Usage: ./run spark.ui.UIWorkloadGenerator [master] + */ +private[spark] object UIWorkloadGenerator { + val NUM_PARTITIONS = 100 + val INTER_JOB_WAIT_MS = 5000 + + def main(args: Array[String]) { + if (args.length < 2) { + println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + System.exit(1) + } + val master = args(0) + val schedulingMode = SchedulingMode.withName(args(1)) + val appName = "Spark UI Tester" + + if (schedulingMode == SchedulingMode.FAIR) { + System.setProperty("spark.cluster.schedulingmode", "FAIR") + } + val sc = new SparkContext(master, appName) + + def setProperties(s: String) = { + if(schedulingMode == SchedulingMode.FAIR) { + sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s) + } + sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) + } + + val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) + def nextFloat() = (new Random()).nextFloat() + + val jobs = Seq[(String, () => Long)]( + ("Count", baseData.count), + ("Cache and Count", baseData.map(x => x).cache.count), + ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), + ("Entirely failed phase", baseData.map(x => throw new Exception).count), + ("Partially failed phase", { + baseData.map{x => + val probFailure = (4.0 / NUM_PARTITIONS) + if (nextFloat() < probFailure) { + throw new Exception("This is a task failure") + } + 1 + }.count + }), + ("Partially failed phase (longer tasks)", { + baseData.map{x => + val probFailure = (4.0 / NUM_PARTITIONS) + if (nextFloat() < probFailure) { + Thread.sleep(100) + throw new Exception("This is a task failure") + } + 1 + }.count + }), + ("Job with delays", baseData.map(x => Thread.sleep(100)).count) + ) + + while (true) { + for ((desc, job) <- jobs) { + new Thread { + override def run() { + try { + setProperties(desc) + job() + println("Job funished: " + desc) + } catch { + case e: Exception => + println("Job Failed: " + desc) + } + } + }.start + Thread.sleep(INTER_JOB_WAIT_MS) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala new file mode 100644 index 0000000000..c5bf2acc9e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.env + +import javax.servlet.http.HttpServletRequest + +import scala.collection.JavaConversions._ +import scala.util.Properties +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.Page.Environment +import org.apache.spark.SparkContext + + +private[spark] class EnvironmentUI(sc: SparkContext) { + + def getHandlers = Seq[(String, Handler)]( + ("/environment", (request: HttpServletRequest) => envDetails(request)) + ) + + def envDetails(request: HttpServletRequest): Seq[Node] = { + val jvmInformation = Seq( + ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), + ("Java Home", Properties.javaHome), + ("Scala Version", Properties.versionString), + ("Scala Home", Properties.scalaHome) + ).sorted + def jvmRow(kv: (String, String)) = {kv._1}{kv._2} + def jvmTable = + UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) + + val properties = System.getProperties.iterator.toSeq + val classPathProperty = properties.find { case (k, v) => + k.contains("java.class.path") + }.getOrElse(("", "")) + val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted + val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted + + val propertyHeaders = Seq("Name", "Value") + def propertyRow(kv: (String, String)) = {kv._1}{kv._2} + val sparkPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true) + val otherPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) + + val classPathEntries = classPathProperty._2 + .split(System.getProperty("path.separator", ":")) + .filterNot(e => e.isEmpty) + .map(e => (e, "System Classpath")) + val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted + + val classPathHeaders = Seq("Resource", "Source") + def classPathRow(data: (String, String)) = {data._1}{data._2} + val classPathTable = + UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true) + + val content = + +

Runtime Information

{jvmTable} +

Spark Properties

+ {sparkPropertyTable} +

System Properties

+ {otherPropertyTable} +

Classpath Entries

+ {classPathTable} +
+ + UIUtils.headerSparkPage(content, sc, "Environment", Environment) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala new file mode 100644 index 0000000000..efe6b474e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -0,0 +1,136 @@ +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable.{HashMap, HashSet} +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{ExceptionFailure, Logging, Utils, SparkContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.Page.Executors +import org.apache.spark.ui.UIUtils + + +private[spark] class ExecutorsUI(val sc: SparkContext) { + + private var _listener: Option[ExecutorsListener] = None + def listener = _listener.get + + def start() { + _listener = Some(new ExecutorsListener) + sc.addSparkListener(listener) + } + + def getHandlers = Seq[(String, Handler)]( + ("/executors", (request: HttpServletRequest) => render(request)) + ) + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + + val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_) + val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) + + val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") + + def execRow(kv: Seq[String]) = { + + {kv(0)} + {kv(1)} + {kv(2)} + + {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)} + + + {Utils.bytesToString(kv(5).toLong)} + + {kv(6)} + {kv(7)} + {kv(8)} + {kv(9)} + + } + + val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execTable = UIUtils.listingTable(execHead, execRow, execInfo) + + val content = +
+
+
    +
  • Memory: + {Utils.bytesToString(memUsed)} Used + ({Utils.bytesToString(maxMem)} Total)
  • +
  • Disk: {Utils.bytesToString(diskSpaceUsed)} Used
  • +
+
+
+
+
+ {execTable} +
+
; + + UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) + } + + def getExecInfo(a: Int): Seq[String] = { + val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId + val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort + val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString + val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString + val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString + val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString + val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) + val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + val totalTasks = activeTasks + failedTasks + completedTasks + + Seq( + execId, + hostPort, + rddBlocks, + memUsed, + maxMem, + diskUsed, + activeTasks.toString, + failedTasks.toString, + completedTasks.toString, + totalTasks.toString + ) + } + + private[spark] class ExecutorsListener extends SparkListener with Logging { + val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() + val executorToTasksComplete = HashMap[String, Int]() + val executorToTasksFailed = HashMap[String, Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + val eid = taskStart.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks += taskStart.taskInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val eid = taskEnd.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 + (Some(e), e.metrics) + case _ => + executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala new file mode 100644 index 0000000000..3b428effaf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} + +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.ui.Page._ +import org.apache.spark.ui.UIUtils._ + + +/** Page showing list of all ongoing and recently finished stages and pools*/ +private[spark] class IndexPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeStages = listener.activeStages.toSeq + val completedStages = listener.completedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse.toSeq + val now = System.currentTimeMillis() + + var activeTime = 0L + for (tasks <- listener.stageToTasksActive.values; t <- tasks) { + activeTime += t.timeRunning(now) + } + + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) + val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + + val pools = listener.sc.getAllPools + val poolTable = new PoolTable(pools, listener) + val summary: NodeSeq = +
+
    +
  • + Total Duration: + {parent.formatDuration(now - listener.sc.startTime)} +
  • +
  • Scheduling Mode: {parent.sc.getSchedulingMode}
  • +
  • + Active Stages: + {activeStages.size} +
  • +
  • + Completed Stages: + {completedStages.size} +
  • +
  • + Failed Stages: + {failedStages.size} +
  • +
+
+ + val content = summary ++ + {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) { +

{pools.size} Fair Scheduler Pools

++ poolTable.toNodeSeq + } else { + Seq() + }} ++ +

Active Stages ({activeStages.size})

++ + activeStagesTable.toNodeSeq++ +

Completed Stages ({completedStages.size})

++ + completedStagesTable.toNodeSeq++ +

Failed Stages ({failedStages.size})

++ + failedStagesTable.toNodeSeq + + headerSparkPage(content, parent.sc, "Spark Stages", Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala new file mode 100644 index 0000000000..ae02226300 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -0,0 +1,156 @@ +package org.apache.spark.ui.jobs + +import scala.Seq +import scala.collection.mutable.{ListBuffer, HashMap, HashSet} + +import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics +import collection.mutable + +/** + * Tracks task-level information to be displayed in the UI. + * + * All access to the data structures in this class must be synchronized on the + * class, since the UI thread and the DAGScheduler event loop may otherwise + * be reading/updating the internal data structures concurrently. + */ +private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { + // How many stages to remember + val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt + val DEFAULT_POOL_NAME = "default" + + val stageToPool = new HashMap[Stage, String]() + val stageToDescription = new HashMap[Stage, String]() + val poolToActiveStages = new HashMap[String, HashSet[Stage]]() + + val activeStages = HashSet[Stage]() + val completedStages = ListBuffer[Stage]() + val failedStages = ListBuffer[Stage]() + + // Total metrics reflect metrics only for completed tasks + var totalTime = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + + val stageToTime = HashMap[Int, Long]() + val stageToShuffleRead = HashMap[Int, Long]() + val stageToShuffleWrite = HashMap[Int, Long]() + val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() + val stageToTasksComplete = HashMap[Int, Int]() + val stageToTasksFailed = HashMap[Int, Int]() + val stageToTaskInfos = + HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + + override def onJobStart(jobStart: SparkListenerJobStart) {} + + override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { + val stage = stageCompleted.stageInfo.stage + poolToActiveStages(stageToPool(stage)) -= stage + activeStages -= stage + completedStages += stage + trimIfNecessary(completedStages) + } + + /** If stages is too large, remove and garbage collect old stages */ + def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { + if (stages.size > RETAINED_STAGES) { + val toRemove = RETAINED_STAGES / 10 + stages.takeRight(toRemove).foreach( s => { + stageToTaskInfos.remove(s.id) + stageToTime.remove(s.id) + stageToShuffleRead.remove(s.id) + stageToShuffleWrite.remove(s.id) + stageToTasksActive.remove(s.id) + stageToTasksComplete.remove(s.id) + stageToTasksFailed.remove(s.id) + stageToPool.remove(s) + if (stageToDescription.contains(s)) {stageToDescription.remove(s)} + }) + stages.trimEnd(toRemove) + } + } + + /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + val stage = stageSubmitted.stage + activeStages += stage + + val poolName = Option(stageSubmitted.properties).map { + p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME) + }.getOrElse(DEFAULT_POOL_NAME) + stageToPool(stage) = poolName + + val description = Option(stageSubmitted.properties).flatMap { + p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + description.map(d => stageToDescription(stage) = d) + + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) + stages += stage + } + + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + val sid = taskStart.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive += taskStart.taskInfo + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList += ((taskStart.taskInfo, None, None)) + stageToTaskInfos(sid) = taskList + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + val sid = taskEnd.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e), e.metrics) + case _ => + stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + + stageToTime.getOrElseUpdate(sid, 0L) + val time = metrics.map(m => m.executorRunTime).getOrElse(0) + stageToTime(sid) += time + totalTime += time + + stageToShuffleRead.getOrElseUpdate(sid, 0L) + val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => + s.remoteBytesRead).getOrElse(0L) + stageToShuffleRead(sid) += shuffleRead + totalShuffleRead += shuffleRead + + stageToShuffleWrite.getOrElseUpdate(sid, 0L) + val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => + s.shuffleBytesWritten).getOrElse(0L) + stageToShuffleWrite(sid) += shuffleWrite + totalShuffleWrite += shuffleWrite + + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList -= ((taskEnd.taskInfo, None, None)) + taskList += ((taskEnd.taskInfo, metrics, failureInfo)) + stageToTaskInfos(sid) = taskList + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + jobEnd match { + case end: SparkListenerJobEnd => + end.jobResult match { + case JobFailed(ex, Some(stage)) => + activeStages -= stage + poolToActiveStages(stageToPool(stage)) -= stage + failedStages += stage + trimIfNecessary(failedStages) + case _ => + } + case _ => + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala new file mode 100644 index 0000000000..1bb7638bd9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import akka.util.Duration + +import java.text.SimpleDateFormat + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import scala.Seq +import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils} +import org.apache.spark.scheduler._ +import collection.mutable +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** Web UI showing progress status of all jobs in the given SparkContext. */ +private[spark] class JobProgressUI(val sc: SparkContext) { + private var _listener: Option[JobProgressListener] = None + def listener = _listener.get + val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + + private val indexPage = new IndexPage(this) + private val stagePage = new StagePage(this) + private val poolPage = new PoolPage(this) + + def start() { + _listener = Some(new JobProgressListener(sc)) + sc.addSparkListener(listener) + } + + def formatDuration(ms: Long) = Utils.msDurationToString(ms) + + def getHandlers = Seq[(String, Handler)]( + ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), + ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), + ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala new file mode 100644 index 0000000000..ce92b6932b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -0,0 +1,32 @@ +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} +import scala.collection.mutable.HashSet + +import org.apache.spark.scheduler.Stage +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + +/** Page showing specific pool details */ +private[spark] class PoolPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val poolName = request.getParameter("poolname") + val poolToActiveStages = listener.poolToActiveStages + val activeStages = poolToActiveStages.get(poolName).toSeq.flatten + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + + val pool = listener.sc.getPoolForName(poolName).get + val poolTable = new PoolTable(Seq(pool), listener) + + val content =

Summary

++ poolTable.toNodeSeq() ++ +

{activeStages.size} Active Stages

++ activeStagesTable.toNodeSeq() + + headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala new file mode 100644 index 0000000000..f31465e59d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -0,0 +1,55 @@ +package org.apache.spark.ui.jobs + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.xml.Node + +import org.apache.spark.scheduler.Stage +import org.apache.spark.scheduler.cluster.Schedulable + +/** Table showing list of pools */ +private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { + + var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + poolTable(poolRow, pools) + } + } + + private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], + rows: Seq[Schedulable] + ): Seq[Node] = { + + + + + + + + + + + {rows.map(r => makeRow(r, poolToActiveStages))} + +
Pool NameMinimum SharePool WeightActive StagesRunning TasksSchedulingMode
+ } + + private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) + : Seq[Node] = { + val activeStages = poolToActiveStages.get(p.name) match { + case Some(stages) => stages.size + case None => 0 + } + + {p.name} + {p.minShare} + {p.weight} + {activeStages} + {p.runningTasks} + {p.schedulingMode} + + } +} + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala new file mode 100644 index 0000000000..2fe85bc0cf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import java.util.Date + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ +import org.apache.spark.util.Distribution +import org.apache.spark.{ExceptionFailure, Utils} +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics + +/** Page showing statistics and task list for a given stage */ +private[spark] class StagePage(parent: JobProgressUI) { + def listener = parent.listener + val dateFmt = parent.dateFmt + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val stageId = request.getParameter("id").toInt + val now = System.currentTimeMillis() + + if (!listener.stageToTaskInfos.contains(stageId)) { + val content = +
+

Summary Metrics

No tasks have started yet +

Tasks

No tasks have started yet +
+ return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) + } + + val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) + + val numCompleted = tasks.count(_._1.finished) + val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) + val hasShuffleRead = shuffleReadBytes > 0 + val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) + val hasShuffleWrite = shuffleWriteBytes > 0 + + var activeTime = 0L + listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + + val summary = +
+
    +
  • + CPU time: + {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)} +
  • + {if (hasShuffleRead) +
  • + Shuffle read: + {Utils.bytesToString(shuffleReadBytes)} +
  • + } + {if (hasShuffleWrite) +
  • + Shuffle write: + {Utils.bytesToString(shuffleWriteBytes)} +
  • + } +
+
+ + val taskHeaders: Seq[String] = + Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ + Seq("GC Time") ++ + {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ + {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + Seq("Errors") + + val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) + + // Excludes tasks which failed and have incomplete metrics + val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined)) + + val summaryTable: Option[Seq[Node]] = + if (validTasks.size == 0) { + None + } + else { + val serviceTimes = validTasks.map{case (info, metrics, exception) => + metrics.get.executorRunTime.toDouble} + val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( + ms => parent.formatDuration(ms.toLong)) + + def getQuantileCols(data: Seq[Double]) = + Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) + + val shuffleReadSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + } + val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes) + + val shuffleWriteSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble + } + val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) + + val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + if (hasShuffleRead) shuffleReadQuantiles else Nil, + if (hasShuffleWrite) shuffleWriteQuantiles else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", + "Median", "75th percentile", "Max") + def quantileRow(data: Seq[String]): Seq[Node] = {data.map(d => {d})} + Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + } + + val content = + summary ++ +

Summary Metrics for {numCompleted} Completed Tasks

++ +
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++ +

Tasks

++ taskTable; + + headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) + } + } + + + def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean) + (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = { + def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = + trace.map(e => {e.toString}) + val (info, metrics, exception) = taskData + + val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) + else metrics.map(m => m.executorRunTime).getOrElse(1) + val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) + else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") + val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) + + + {info.taskId} + {info.status} + {info.taskLocality} + {info.host} + {dateFmt.format(new Date(info.launchTime))} + + {formatDuration} + + + {if (gcTime > 0) parent.formatDuration(gcTime) else ""} + + {if (shuffleRead) { + {metrics.flatMap{m => m.shuffleReadMetrics}.map{s => + Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")} + }} + {if (shuffleWrite) { + {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")} + }} + {exception.map(e => + + {e.className} ({e.description})
+ {fmtStackTrace(e.stackTrace)} +
).getOrElse("")} + + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala new file mode 100644 index 0000000000..beb0574548 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -0,0 +1,107 @@ +package org.apache.spark.ui.jobs + +import java.util.Date + +import scala.xml.Node +import scala.collection.mutable.HashSet + +import org.apache.spark.Utils +import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo} +import org.apache.spark.scheduler.Stage + + +/** Page showing list of all ongoing and recently finished stages */ +private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + stageTable(stageRow, stages) + } + } + + /** Special table which merges two header cells. */ + private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = { + + + + {if (isFairScheduler) {} else {}} + + + + + + + + + {rows.map(r => makeRow(r))} + +
Stage IdPool NameDescriptionSubmittedDurationTasks: Succeeded/TotalShuffle ReadShuffle Write
+ } + + private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + +
+ + {completed}/{total} {failed} + +
+
+
+ } + + + private def stageRow(s: Stage): Seq[Node] = { + val submissionTime = s.submissionTime match { + case Some(t) => dateFmt.format(new Date(t)) + case None => "Unknown" + } + + val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + + val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size + val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) + val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { + case f if f > 0 => "(%s failed)".format(f) + case _ => "" + } + val totalTasks = s.numPartitions + + val poolName = listener.stageToPool.get(s) + + val nameLink = {s.name} + val description = listener.stageToDescription.get(s) + .map(d =>
{d}
{nameLink}
).getOrElse(nameLink) + val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) + val duration = s.submissionTime.map(t => finishTime - t) + + + {s.id} + {if (isFairScheduler) { + {poolName.get}} + } + {description} + {submissionTime} + + {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")} + + + {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} + + {shuffleRead} + {shuffleWrite} + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..1d633d374a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import akka.util.Duration + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.ui.JettyUtils._ + +/** Web UI showing storage status of all RDD's in the given SparkContext. */ +private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { + implicit val timeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + val indexPage = new IndexPage(this) + val rddPage = new RDDPage(this) + + def getHandlers = Seq[(String, Handler)]( + ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), + ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala new file mode 100644 index 0000000000..1eb4a7a85e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.storage.{RDDInfo, StorageUtils} +import org.apache.spark.Utils +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + +/** Page showing list of RDD's currently stored in the cluster */ +private[spark] class IndexPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + // Calculate macro-level statistics + + val rddHeaders = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size on Disk") + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + val content = listingTable(rddHeaders, rddRow, rdds) + + headerSparkPage(content, parent.sc, "Storage ", Storage) + } + + def rddRow(rdd: RDDInfo): Seq[Node] = { + + + + {rdd.name} + + + {rdd.storageLevel.description} + + {rdd.numCachedPartitions} + {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} + {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.diskSize)} + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala new file mode 100644 index 0000000000..37baf17f7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.Utils +import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + + +/** Page showing storage details for a given RDD */ +private[spark] class RDDPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val id = request.getParameter("id") + val prefix = "rdd_" + id.toString + val storageStatusList = sc.getExecutorStorageStatus + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + + val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") + val workers = filteredStorageStatusList.map((prefix, _)) + val workerTable = listingTable(workerHeaders, workerRow, workers) + + val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", + "Executors") + + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) + val blocks = blockStatuses.map { + case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) + } + val blockTable = listingTable(blockHeaders, blockRow, blocks) + + val content = +
+
+
    +
  • + Storage Level: + {rddInfo.storageLevel.description} +
  • +
  • + Cached Partitions: + {rddInfo.numCachedPartitions} +
  • +
  • + Total Partitions: + {rddInfo.numPartitions} +
  • +
  • + Memory Size: + {Utils.bytesToString(rddInfo.memSize)} +
  • +
  • + Disk Size: + {Utils.bytesToString(rddInfo.diskSize)} +
  • +
+
+
+ +
+
+

Data Distribution on {workers.size} Executors

+ {workerTable} +
+
+ +
+
+

{blocks.size} Partitions

+ {blockTable} +
+
; + + headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) + } + + def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + val (id, block, locations) = row + + {id} + + {block.storageLevel.description} + + + {Utils.bytesToString(block.memSize)} + + + {Utils.bytesToString(block.diskSize)} + + + {locations.map(l => {l}
)} + + + } + + def workerRow(worker: (String, StorageStatus)): Seq[Node] = { + val (prefix, status) = worker + + {status.blockManagerId.host + ":" + status.blockManagerId.port} + + {Utils.bytesToString(status.memUsed(prefix))} + ({Utils.bytesToString(status.memRemaining)} Remaining) + + {Utils.bytesToString(status.diskUsed(prefix))} + + } +} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala new file mode 100644 index 0000000000..d4c5065c3f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.{ActorSystem, ExtendedActorSystem} +import com.typesafe.config.ConfigFactory +import akka.util.duration._ +import akka.remote.RemoteActorRefProvider + + +/** + * Various utility classes for working with Akka. + */ +private[spark] object AkkaUtils { + + /** + * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the + * ActorSystem itself and its port (which is hard to get from Akka). + * + * Note: the `name` parameter is important, as even if a client sends a message to right + * host + port, if the system name is incorrect, Akka will drop the message. + */ + def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt + val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. + val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + + val akkaConf = ConfigFactory.parseString(""" + akka.daemonic = on + akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" + akka.actor.provider = "akka.remote.RemoteActorRefProvider" + akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.netty.hostname = "%s" + akka.remote.netty.port = %d + akka.remote.netty.connection-timeout = %ds + akka.remote.netty.message-frame-size = %d MiB + akka.remote.netty.execution-pool-size = %d + akka.actor.default-dispatcher.throughput = %d + akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds + """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, + lifecycleEvents, akkaWriteTimeout)) + + val actorSystem = ActorSystem(name, akkaConf) + + // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a + // hack because Akka doesn't let you figure out the port through the public API yet. + val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider + val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get + return (actorSystem, boundPort) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..0b51c23f7b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. + */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala new file mode 100644 index 0000000000..e214d2a519 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.InputStream +import java.nio.ByteBuffer +import org.apache.spark.storage.BlockManager + +/** + * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * at the end of the stream (e.g. to close a memory-mapped file). + */ +private[spark] +class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) + extends InputStream { + + override def read(): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + buffer.get() & 0xFF + } + } + + override def read(dest: Array[Byte]): Int = { + read(dest, 0, dest.length) + } + + override def read(dest: Array[Byte], offset: Int, length: Int): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + val amountToGet = math.min(buffer.remaining(), length) + buffer.get(dest, offset, amountToGet) + amountToGet + } + } + + override def skip(bytes: Long): Long = { + if (buffer != null) { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + if (buffer.remaining() == 0) { + cleanUp() + } + amountToSkip + } else { + 0L + } + } + + /** + * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). + */ + private def cleanUp() { + if (buffer != null) { + if (dispose) { + BlockManager.dispose(buffer) + } + buffer = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala new file mode 100644 index 0000000000..97c2b45aab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An interface to represent clocks, so that they can be mocked out in unit tests. + */ +private[spark] trait Clock { + def getTime(): Long +} + +private[spark] object SystemClock extends Clock { + def getTime(): Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala new file mode 100644 index 0000000000..dc15a38b29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements + */ +abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ + def next = sub.next + def hasNext = { + val r = sub.hasNext + if (!r) { + completion + } + r + } + + def completion() +} + +object CompletionIterator { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + new CompletionIterator[A,I](sub) { + def completion() = completionFunction + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala new file mode 100644 index 0000000000..33bf3562fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +/** + * Util for getting some stats from a small sample of numeric values, with some handy summary functions. + * + * Entirely in memory, not intended as a good way to compute stats over large data sets. + * + * Assumes you are giving it a non-empty set of data + */ +class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { + require(startIdx < endIdx) + def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) + java.util.Arrays.sort(data, startIdx, endIdx) + val length = endIdx - startIdx + + val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) + + /** + * Get the value of the distribution at the given probabilities. Probabilities should be + * given from 0 to 1 + * @param probabilities + */ + def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = { + probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} + } + + private def closestIndex(p: Double) = { + math.min((p * length).toInt + startIdx, endIdx - 1) + } + + def showQuantiles(out: PrintStream = System.out) = { + out.println("min\t25%\t50%\t75%\tmax") + getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} + out.println + } + + def statCounter = StatCounter(data.slice(startIdx, endIdx)) + + /** + * print a summary of this distribution to the given PrintStream. + * @param out + */ + def summary(out: PrintStream = System.out) { + out.println(statCounter) + showQuantiles(out) + } +} + +object Distribution { + + def apply(data: Traversable[Double]): Option[Distribution] = { + if (data.size > 0) + Some(new Distribution(data)) + else + None + } + + def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + out.println("min\t25%\t50%\t75%\tmax") + quantiles.foreach{q => out.print(q + "\t")} + out.println + } +} diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..17e55f7996 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. + */ +private[spark] class IdGenerator { + private var id = new AtomicInteger + def next: Int = id.incrementAndGet +} diff --git a/core/src/main/scala/org/apache/spark/util/IntParam.scala b/core/src/main/scala/org/apache/spark/util/IntParam.scala new file mode 100644 index 0000000000..626bb49eea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IntParam.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An extractor object for parsing strings into integers. + */ +private[spark] object IntParam { + def unapply(str: String): Option[Int] = { + try { + Some(str.toInt) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala new file mode 100644 index 0000000000..0ee6707826 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.Utils + +/** + * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing + * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. + */ +private[spark] object MemoryParam { + def unapply(str: String): Option[Int] = { + try { + Some(Utils.memoryStringToMb(str)) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..a430a75451 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import org.apache.spark.Logging + + +/** + * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) + */ +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) + + private val task = new TimerTask { + override def run() { + try { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + + if (delaySeconds > 0) { + logDebug( + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} + + +object MetadataCleaner { + def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } +} + diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala new file mode 100644 index 0000000000..34f1f6606f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + + +/** + * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to + * minimize object allocation. + * + * @param _1 Element 1 of this MutablePair + * @param _2 Element 2 of this MutablePair + */ +case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, + @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] + (var _1: T1, var _2: T2) + extends Product2[T1, T2] +{ + override def toString = "(" + _1 + "," + _2 + ")" + + override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] +} diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala new file mode 100644 index 0000000000..8266e5e495 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** Provides a basic/boilerplate Iterator implementation. */ +private[spark] abstract class NextIterator[U] extends Iterator[U] { + + private var gotNext = false + private var nextValue: U = _ + private var closed = false + protected var finished = false + + /** + * Method for subclasses to implement to provide the next element. + * + * If no next element is available, the subclass should set `finished` + * to `true` and may return any value (it will be ignored). + * + * This convention is required because `null` may be a valid value, + * and using `Option` seems like it might create unnecessary Some/None + * instances, given some iterators might be called in a tight loop. + * + * @return U, or set 'finished' when done + */ + protected def getNext(): U + + /** + * Method for subclasses to implement when all elements have been successfully + * iterated, and the iteration is done. + * + * Note: `NextIterator` cannot guarantee that `close` will be + * called because it has no control over what happens when an exception + * happens in the user code that is calling hasNext/next. + * + * Ideally you should have another try/catch, as in HadoopRDD, that + * ensures any resources are closed should iteration fail. + */ + protected def close() + + /** + * Calls the subclass-defined close method, but only once. + * + * Usually calling `close` multiple times should be fine, but historically + * there have been issues with some InputFormats throwing exceptions. + */ + def closeIfNeeded() { + if (!closed) { + close() + closed = true + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + closeIfNeeded() + } + gotNext = true + } + } + !finished + } + + override def next(): U = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } +} diff --git a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..47e1b45004 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.annotation.tailrec + +import java.io.OutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 + var lastSyncTime = System.nanoTime + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, CHUNK_SIZE) + if (writeSize > 0) { + waitToWrite(writeSize) + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SYNC_INTERVAL) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala new file mode 100644 index 0000000000..f2b1ad7d0e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.nio.ByteBuffer +import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream} +import java.nio.channels.Channels + +/** + * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make + * it easier to pass ByteBuffers in case class messages. + */ +private[spark] +class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { + def value = buffer + + private def readObject(in: ObjectInputStream) { + val length = in.readInt() + buffer = ByteBuffer.allocate(length) + var amountRead = 0 + val channel = Channels.newChannel(in) + while (amountRead < length) { + val ret = channel.read(buffer) + if (ret == -1) { + throw new EOFException("End of file before fully reading buffer") + } + amountRead += ret + } + buffer.rewind() // Allow us to read it later + } + + private def writeObject(out: ObjectOutputStream) { + out.writeInt(buffer.limit()) + if (Channels.newChannel(out).write(buffer) != buffer.limit()) { + throw new IOException("Could not fully write buffer to output stream") + } + buffer.rewind() // Allow us to write it again later + } +} diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala new file mode 100644 index 0000000000..020d5edba9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A class for tracking the statistics of a set of numbers (count, mean and variance) in a + * numerically robust way. Includes support for merging two StatCounters. Based on + * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. + * + * @constructor Initialize the StatCounter with the given values. + */ +class StatCounter(values: TraversableOnce[Double]) extends Serializable { + private var n: Long = 0 // Running count of our values + private var mu: Double = 0 // Running mean of our values + private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) + + merge(values) + + /** Initialize the StatCounter with no values. */ + def this() = this(Nil) + + /** Add a value into this StatCounter, updating the internal statistics. */ + def merge(value: Double): StatCounter = { + val delta = value - mu + n += 1 + mu += delta / n + m2 += delta * (value - mu) + this + } + + /** Add multiple values into this StatCounter, updating the internal statistics. */ + def merge(values: TraversableOnce[Double]): StatCounter = { + values.foreach(v => merge(v)) + this + } + + /** Merge another StatCounter into this one, adding up the internal statistics. */ + def merge(other: StatCounter): StatCounter = { + if (other == this) { + merge(other.copy()) // Avoid overwriting fields in a weird order + } else { + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n + } + this + } + } + + /** Clone this StatCounter */ + def copy(): StatCounter = { + val other = new StatCounter + other.n = n + other.mu = mu + other.m2 = m2 + other + } + + def count: Long = n + + def mean: Double = mu + + def sum: Double = n * mu + + /** Return the variance of the values. */ + def variance: Double = { + if (n == 0) + Double.NaN + else + m2 / n + } + + /** + * Return the sample variance, which corrects for bias in estimating the variance by dividing + * by N-1 instead of N. + */ + def sampleVariance: Double = { + if (n <= 1) + Double.NaN + else + m2 / (n - 1) + } + + /** Return the standard deviation of the values. */ + def stdev: Double = math.sqrt(variance) + + /** + * Return the sample standard deviation of the values, which corrects for bias in estimating the + * variance by dividing by N-1 instead of N. + */ + def sampleStdev: Double = math.sqrt(sampleVariance) + + override def toString: String = { + "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) + } +} + +object StatCounter { + /** Build a StatCounter from a list of values. */ + def apply(values: TraversableOnce[Double]) = new StatCounter(values) + + /** Build a StatCounter from a list of values passed as variable-length arguments. */ + def apply(values: Double*) = new StatCounter(values) +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..277de2f8a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions +import scala.collection.mutable.Map +import scala.collection.immutable +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.Logging + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def toMap: immutable.Map[A, B] = iterator.toMap + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..26983138ff --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala new file mode 100644 index 0000000000..fe710c58ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +class Vector(val elements: Array[Double]) extends Serializable { + def length = elements.length + + def apply(index: Int) = elements(index) + + def + (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) + other(i)) + } + + def add(other: Vector) = this + other + + def - (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) - other(i)) + } + + def subtract(other: Vector) = this - other + + def dot(other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += this(i) * other(i) + i += 1 + } + return ans + } + + /** + * return (this + plus) dot other, but without creating any intermediate storage + * @param plus + * @param other + * @return + */ + def plusDot(plus: Vector, other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + if (length != plus.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) + plus(i)) * other(i) + i += 1 + } + return ans + } + + def += (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var i = 0 + while (i < length) { + elements(i) += other(i) + i += 1 + } + this + } + + def addInPlace(other: Vector) = this +=other + + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + + def multiply (d: Double) = this * d + + def / (d: Double): Vector = this * (1 / d) + + def divide (d: Double) = this / d + + def unary_- = this * -1 + + def sum = elements.reduceLeft(_ + _) + + def squaredDist(other: Vector): Double = { + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) - other(i)) * (this(i) - other(i)) + i += 1 + } + return ans + } + + def dist(other: Vector): Double = math.sqrt(squaredDist(other)) + + override def toString = elements.mkString("(", ", ", ")") +} + +object Vector { + def apply(elements: Array[Double]) = new Vector(elements) + + def apply(elements: Double*) = new Vector(elements.toArray) + + def apply(length: Int, initializer: Int => Double): Vector = { + val elements: Array[Double] = Array.tabulate(length)(initializer) + return new Vector(elements) + } + + def zeros(length: Int) = new Vector(new Array[Double](length)) + + def ones(length: Int) = Vector(length, _ => 1) + + class Multiplier(num: Double) { + def * (vec: Vector) = vec * num + } + + implicit def doubleToMultiplier(num: Double) = new Multiplier(num) + + implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { + def addInPlace(t1: Vector, t2: Vector) = t1 + t2 + + def zero(initialValue: Vector) = Vector.zeros(initialValue.length) + } + +} diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala deleted file mode 100644 index 6ff92ce833..0000000000 --- a/core/src/main/scala/spark/Accumulators.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ - -import scala.collection.mutable.Map -import scala.collection.generic.Growable - -/** - * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation, - * but where the result type, `R`, may be different from the element type being added, `T`. - * - * You must define how to add data, and how to merge two of these together. For some datatypes, - * such as a counter, these might be the same operation. In that case, you can use the simpler - * [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are - * accumulating a set. You will add items to the set, and you will union two sets together. - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `R` and `T` - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -class Accumulable[R, T] ( - @transient initialValue: R, - param: AccumulableParam[R, T]) - extends Serializable { - - val id = Accumulators.newId - @transient private var value_ = initialValue // Current value on master - val zero = param.zero(initialValue) // Zero value to be passed to workers - var deserialized = false - - Accumulators.register(this, true) - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def += (term: T) { value_ = param.addAccumulator(value_, term) } - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def add(term: T) { value_ = param.addAccumulator(value_, term) } - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other `R` that will get merged with this - */ - def ++= (term: R) { value_ = param.addInPlace(value_, term)} - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `add`. - * @param term the other `R` that will get merged with this - */ - def merge(term: R) { value_ = param.addInPlace(value_, term)} - - /** - * Access the accumulator's current value; only allowed on master. - */ - def value: R = { - if (!deserialized) { - value_ - } else { - throw new UnsupportedOperationException("Can't read accumulator value in task") - } - } - - /** - * Get the current value of this accumulator from within a task. - * - * This is NOT the global value of the accumulator. To get the global value after a - * completed operation on the dataset, call `value`. - * - * The typical use of this method is to directly mutate the local value, eg., to add - * an element to a Set. - */ - def localValue = value_ - - /** - * Set the accumulator's value; only allowed on master. - */ - def value_= (newValue: R) { - if (!deserialized) value_ = newValue - else throw new UnsupportedOperationException("Can't assign accumulator value in task") - } - - /** - * Set the accumulator's value; only allowed on master - */ - def setValue(newValue: R) { - this.value = newValue - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - value_ = zero - deserialized = true - Accumulators.register(this, false) - } - - override def toString = value_.toString -} - -/** - * Helper object defining how to accumulate values of a particular type. An implicit - * AccumulableParam needs to be available when you create Accumulables of a specific type. - * - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -trait AccumulableParam[R, T] extends Serializable { - /** - * Add additional data to the accumulator value. Is allowed to modify and return `r` - * for efficiency (to avoid allocating objects). - * - * @param r the current value of the accumulator - * @param t the data to be added to the accumulator - * @return the new value of the accumulator - */ - def addAccumulator(r: R, t: T): R - - /** - * Merge two accumulated values together. Is allowed to modify and return the first value - * for efficiency (to avoid allocating objects). - * - * @param r1 one set of accumulated data - * @param r2 another set of accumulated data - * @return both data sets merged together - */ - def addInPlace(r1: R, r2: R): R - - /** - * Return the "zero" (identity) value for an accumulator type, given its initial value. For - * example, if R was a vector of N dimensions, this would return a vector of N zeroes. - */ - def zero(initialValue: R): R -} - -private[spark] -class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] - extends AccumulableParam[R,T] { - - def addAccumulator(growable: R, elem: T): R = { - growable += elem - growable - } - - def addInPlace(t1: R, t2: R): R = { - t1 ++= t2 - t1 - } - - def zero(initialValue: R): R = { - // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. - // Instead we'll serialize it to a buffer and load it back. - val ser = (new spark.JavaSerializer).newInstance() - val copy = ser.deserialize[R](ser.serialize(initialValue)) - copy.clear() // In case it contained stuff - copy - } -} - -/** - * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same - * as the types of elements being merged. - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `T` - * @tparam T result type - */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) - -/** - * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type - * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create - * Accumulators of a specific type. - * - * @tparam T type of value to accumulate - */ -trait AccumulatorParam[T] extends AccumulableParam[T, T] { - def addAccumulator(t1: T, t2: T): T = { - addInPlace(t1, t2) - } -} - -// TODO: The multi-thread support in accumulators is kind of lame; check -// if there's a more intuitive way of doing it right -private object Accumulators { - // TODO: Use soft references? => need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() - var lastId: Long = 0 - - def newId: Long = synchronized { - lastId += 1 - return lastId - } - - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = a - } else { - val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map()) - accums(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.remove(Thread.currentThread) - } - } - - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) { - ret(id) = accum.localValue - } - return ret - } - - // Add values to the original accumulators with some given IDs - def add(values: Map[Long, Any]): Unit = synchronized { - for ((id, value) <- values) { - if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value - } - } - } -} diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala deleted file mode 100644 index 9af401986d..0000000000 --- a/core/src/main/scala/spark/Aggregator.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConversions._ - -/** A set of functions used to aggregate data. - * - * @param createCombiner function to create the initial value of the aggregation. - * @param mergeValue function to merge a new value into the aggregation result. - * @param mergeCombiners function to merge outputs from multiple mergeValue function. - */ -case class Aggregator[K, V, C] ( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) { - - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - for (kv <- iter) { - val oldC = combiners.get(kv._1) - if (oldC == null) { - combiners.put(kv._1, createCombiner(kv._2)) - } else { - combiners.put(kv._1, mergeValue(oldC, kv._2)) - } - } - combiners.iterator - } - - def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - iter.foreach { case(k, c) => - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, mergeCombiners(oldC, c)) - } - } - combiners.iterator - } -} - diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 1ec95ed9b8..0000000000 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.executor.{ShuffleReadMetrics, TaskMetrics} -import spark.serializer.Serializer -import spark.storage.BlockManagerId -import spark.util.CompletionIterator - - -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) - : Iterator[T] = - { - - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager - - val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) - } - - def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = { - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Some(block) => { - block.asInstanceOf[Iterator[T]] - } - case None => { - val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r - blockId match { - case regex(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") - } - } - } - } - - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) - val itr = blockFetcherItr.flatMap(unpackBlock) - - CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - metrics.shuffleReadMetrics = Some(shuffleMetrics) - }) - } -} diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala deleted file mode 100644 index 81314805a9..0000000000 --- a/core/src/main/scala/spark/CacheManager.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.{ArrayBuffer, HashSet} -import spark.storage.{BlockManager, StorageLevel} - - -/** Spark class responsible for passing RDDs split contents to the BlockManager and making - sure a node doesn't load two copies of an RDD at once. - */ -private[spark] class CacheManager(blockManager: BlockManager) extends Logging { - private val loading = new HashSet[String] - - /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ - def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Partition is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ : Throwable =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.computeOrReadCheckpoint(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } -} diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala deleted file mode 100644 index 8b39241095..0000000000 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.lang.reflect.Field - -import scala.collection.mutable.Map -import scala.collection.mutable.Set - -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.objectweb.asm.Opcodes._ -import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} - -private[spark] object ClosureCleaner extends Logging { - // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { - // Copy data over, before delegating to ClassReader - else we can run out of open file handles. - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) - } - - // Check whether a class represents a Scala closure - private def isClosure(cls: Class[_]): Boolean = { - cls.getName.contains("$anonfun$") - } - - // Get a list of the classes of the outer objects of a given closure object, obj; - // the outer objects are defined as any closures that obj is nested within, plus - // possibly the class that the outermost closure is in, if any. We stop searching - // for outer objects beyond that because cloning the user's object is probably - // not a good idea (whereas we can clone closure objects just fine since we - // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(f.get(obj)) - } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - // Get a list of the outer objects for a given closure object. - private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.get(obj) :: getOuterObjects(f.get(obj)) - } else { - return f.get(obj) :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - private def getInnerClasses(obj: AnyRef): List[Class[_]] = { - val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) - while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail - val set = Set[Class[_]]() - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack = cls :: stack - } - } - return (seen - obj.getClass).toList - } - - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type - } else { - null - } - } - - def clean(func: AnyRef) { - // TODO: cache outerClasses / innerClasses / accessedFields - val outerClasses = getOuterClasses(func) - val innerClasses = getInnerClasses(func) - val outerObjects = getOuterObjects(func) - - val accessedFields = Map[Class[_], Set[String]]() - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - //logInfo("accessedFields: " + accessedFields) - - val inInterpreter = { - try { - val interpClass = Class.forName("spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => true - } - } - - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 - outerPairs = outerPairs.tail - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer, inInterpreter) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); - field.set(outer, value) - } - } - - if (outer != null) { - //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - field.set(func, outer) - } - } - - private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - //logInfo("Creating a " + cls + " with outer = " + outer) - if (!inInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (outer != null) - params(0) = outer // First param is always outer object - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (outer != null) { - //logInfo("3: Setting $outer on " + cls + " to " + outer); - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, outer) - } - return obj - } - } -} - -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { - if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. - if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - } - } -} - -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { - var myName: String = null - - override def visit(version: Int, access: Int, name: String, sig: String, - superName: String, interfaces: Array[String]) { - myName = name - } - - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - val argTypes = Type.getArgumentTypes(desc) - if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 - && argTypes(0).toString.startsWith("L") // is it an object? - && argTypes(0).getInternalName == myName) - output += Class.forName( - owner.replace('/', '.'), - false, - Thread.currentThread.getContextClassLoader) - } - } - } -} diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala deleted file mode 100644 index d5a9606570..0000000000 --- a/core/src/main/scala/spark/Dependency.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * Base class for dependencies. - */ -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable - - -/** - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. - */ -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { - /** - * Get the parent partitions for a child partition. - * @param partitionId a partition of the child RDD - * @return the partitions of the parent RDD that the child partition depends upon - */ - def getParents(partitionId: Int): Seq[Int] -} - - -/** - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD - * @param partitioner partitioner used to partition the shuffle output - * @param serializerClass class name of the serializer to use - */ -class ShuffleDependency[K, V]( - @transient rdd: RDD[_ <: Product2[K, V]], - val partitioner: Partitioner, - val serializerClass: String = null) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { - - val shuffleId: Int = rdd.context.newShuffleId() -} - - -/** - * Represents a one-to-one dependency between partitions of the parent and child RDDs. - */ -class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) -} - - -/** - * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. - * @param rdd the parent RDD - * @param inStart the start of the range in the parent RDD - * @param outStart the start of the range in the child RDD - * @param length the length of the range - */ -class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) - extends NarrowDependency[T](rdd) { - - override def getParents(partitionId: Int) = { - if (partitionId >= outStart && partitionId < outStart + length) { - List(partitionId - outStart + inStart) - } else { - Nil - } - } -} diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala deleted file mode 100644 index 104168e61c..0000000000 --- a/core/src/main/scala/spark/DoubleRDDFunctions.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.partial.BoundedDouble -import spark.partial.MeanEvaluator -import spark.partial.PartialResult -import spark.partial.SumEvaluator -import spark.util.StatCounter - -/** - * Extra functions available on RDDs of Doubles through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { - /** Add up the elements in this RDD. */ - def sum(): Double = { - self.reduce(_ + _) - } - - /** - * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count - * of the RDD's elements in one operation. - */ - def stats(): StatCounter = { - self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) - } - - /** Compute the mean of this RDD's elements. */ - def mean(): Double = stats().mean - - /** Compute the variance of this RDD's elements. */ - def variance(): Double = stats().variance - - /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = stats().stdev - - /** - * Compute the sample standard deviation of this RDD's elements (which corrects for bias in - * estimating the standard deviation by dividing by N-1 instead of N). - */ - def sampleStdev(): Double = stats().sampleStdev - - /** - * Compute the sample variance of this RDD's elements (which corrects for bias in - * estimating the variance by dividing by N-1 instead of N). - */ - def sampleVariance(): Double = stats().sampleVariance - - /** (Experimental) Approximate operation to return the mean within a timeout. */ - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } -} diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala deleted file mode 100644 index a2dae6cae9..0000000000 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.storage.BlockManagerId - -private[spark] class FetchFailedException( - taskEndReason: TaskEndReason, - message: String, - cause: Throwable) - extends Exception { - - def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = - this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), - cause) - - def this (shuffleId: Int, reduceId: Int, cause: Throwable) = - this(FetchFailed(null, shuffleId, -1, reduceId), - "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) - - override def getMessage(): String = message - - - override def getCause(): Throwable = cause - - def toTaskEndReason: TaskEndReason = taskEndReason - -} diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala deleted file mode 100644 index a13a7a2859..0000000000 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.{File} -import com.google.common.io.Files - -private[spark] class HttpFileServer extends Logging { - - var baseDir : File = null - var fileDir : File = null - var jarDir : File = null - var httpServer : HttpServer = null - var serverUri : String = null - - def initialize() { - baseDir = Utils.createTempDir() - fileDir = new File(baseDir, "files") - jarDir = new File(baseDir, "jars") - fileDir.mkdir() - jarDir.mkdir() - logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) - httpServer.start() - serverUri = httpServer.uri - } - - def stop() { - httpServer.stop() - } - - def addFile(file: File) : String = { - addFileToDir(file, fileDir) - return serverUri + "/files/" + file.getName - } - - def addJar(file: File) : String = { - addFileToDir(file, jarDir) - return serverUri + "/jars/" + file.getName - } - - def addFileToDir(file: File, dir: File) : String = { - Files.copy(file, new File(dir, file.getName)) - return dir + "/" + file.getName - } - -} diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala deleted file mode 100644 index c9dffbc631..0000000000 --- a/core/src/main/scala/spark/HttpServer.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.File -import java.net.InetAddress - -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import org.eclipse.jetty.server.handler.ResourceHandler -import org.eclipse.jetty.util.thread.QueuedThreadPool - -/** - * Exception type thrown by HttpServer when it is in the wrong state for an operation. - */ -private[spark] class ServerStateException(message: String) extends Exception(message) - -/** - * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext - * as well as classes created by the interpreter when the user types in code. This is just a wrapper - * around a Jetty server. - */ -private[spark] class HttpServer(resourceBase: File) extends Logging { - private var server: Server = null - private var port: Int = -1 - - def start() { - if (server != null) { - throw new ServerStateException("Server is already started") - } else { - server = new Server() - val connector = new SocketConnector - connector.setMaxIdleTime(60*1000) - connector.setSoLingerTime(-1) - connector.setPort(0) - server.addConnector(connector) - - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) - server.start() - port = server.getConnectors()(0).getLocalPort() - } - } - - def stop() { - if (server == null) { - throw new ServerStateException("Server is already stopped") - } else { - server.stop() - port = -1 - server = null - } - } - - /** - * Get the URI of this HTTP server (http://host:port) - */ - def uri: String = { - if (server == null) { - throw new ServerStateException("Server is not started") - } else { - return "http://" + Utils.localIpAddress + ":" + port - } - } -} diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala deleted file mode 100644 index 04c5f44e6b..0000000000 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.nio.ByteBuffer - -import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream} -import spark.util.ByteBufferInputStream - -private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { - val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } - def flush() { objOut.flush() } - def close() { objOut.close() } -} - -private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - - def readObject[T](): T = objIn.readObject().asInstanceOf[T] - def close() { objIn.close() } -} - -private[spark] class JavaSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() - val out = serializeStream(bos) - out.writeObject(t) - out.close() - ByteBuffer.wrap(bos.toByteArray) - } - - def deserialize[T](bytes: ByteBuffer): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] - } - - def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) - } - - def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { - new JavaDeserializationStream(s, loader) - } -} - -/** - * A Spark serializer that uses Java's built-in serialization. - */ -class JavaSerializer extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance -} diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala deleted file mode 100644 index eeb2993d8a..0000000000 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.nio.ByteBuffer -import com.esotericsoftware.kryo.{Kryo, KryoException} -import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} -import com.twitter.chill.ScalaKryoInstantiator -import serializer.{SerializerInstance, DeserializationStream, SerializationStream} -import spark.broadcast._ -import spark.storage._ - -private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) - - def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(output, t) - this - } - - def flush() { output.flush() } - def close() { output.close() } -} - -private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) - - def readObject[T](): T = { - try { - kryo.readClassAndObject(input).asInstanceOf[T] - } catch { - // DeserializationStream uses the EOF exception to indicate stopping condition. - case _: KryoException => throw new EOFException - } - } - - def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() - } -} - -private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() - val output = ks.newKryoOutput() - val input = ks.newKryoInput() - - def serialize[T](t: T): ByteBuffer = { - output.clear() - kryo.writeClassAndObject(output, t) - ByteBuffer.wrap(output.toBytes) - } - - def deserialize[T](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj - } - - def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) - } -} - -/** - * Interface implemented by clients to register their classes with Kryo when using Kryo - * serialization. - */ -trait KryoRegistrator { - def registerClasses(kryo: Kryo) -} - -/** - * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. - */ -class KryoSerializer extends spark.serializer.Serializer with Logging { - private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - - def newKryoOutput() = new KryoOutput(bufferSize) - - def newKryoInput() = new KryoInput(bufferSize) - - def newKryo(): Kryo = { - val instantiator = new ScalaKryoInstantiator - val kryo = instantiator.newKryo() - val classLoader = Thread.currentThread.getContextClassLoader - - // Register some commonly used classes - val toRegister: Seq[AnyRef] = Seq( - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1") - ) - - for (obj <- toRegister) kryo.register(obj.getClass) - - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) - - // Allow the user to register their own classes by setting spark.kryo.registrator - try { - Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => - logDebug("Running user registrator: " + regCls) - val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] - reg.registerClasses(kryo) - } - } catch { - case _: Exception => println("Failed to register spark.kryo.registrator") - } - - kryo.setClassLoader(classLoader) - - // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops - kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) - - kryo - } - - def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this) - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala deleted file mode 100644 index 79b0362830..0000000000 --- a/core/src/main/scala/spark/Logging.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -/** - * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows - * logging messages at different levels using methods that only evaluate parameters lazily if the - * log level is enabled. - */ -trait Logging { - // Make the log field transient so that objects with Logging can - // be serialized and used on another machine - @transient private var log_ : Logger = null - - // Method to get or create the logger for this object - protected def log: Logger = { - if (log_ == null) { - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) - } - return log_ - } - - // Log methods that take only a String - protected def logInfo(msg: => String) { - if (log.isInfoEnabled) log.info(msg) - } - - protected def logDebug(msg: => String) { - if (log.isDebugEnabled) log.debug(msg) - } - - protected def logTrace(msg: => String) { - if (log.isTraceEnabled) log.trace(msg) - } - - protected def logWarning(msg: => String) { - if (log.isWarnEnabled) log.warn(msg) - } - - protected def logError(msg: => String) { - if (log.isErrorEnabled) log.error(msg) - } - - // Log methods that take Throwables (Exceptions/Errors) too - protected def logInfo(msg: => String, throwable: Throwable) { - if (log.isInfoEnabled) log.info(msg, throwable) - } - - protected def logDebug(msg: => String, throwable: Throwable) { - if (log.isDebugEnabled) log.debug(msg, throwable) - } - - protected def logTrace(msg: => String, throwable: Throwable) { - if (log.isTraceEnabled) log.trace(msg, throwable) - } - - protected def logWarning(msg: => String, throwable: Throwable) { - if (log.isWarnEnabled) log.warn(msg, throwable) - } - - protected def logError(msg: => String, throwable: Throwable) { - if (log.isErrorEnabled) log.error(msg, throwable) - } - - protected def isTraceEnabled(): Boolean = { - log.isTraceEnabled - } - - // Method for ensuring that logging is initialized, to avoid having multiple - // threads do it concurrently (as SLF4J initialization is not thread safe). - protected def initLogging() { log } -} diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala deleted file mode 100644 index 0cd0341a72..0000000000 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration - - -import spark.scheduler.MapStatus -import spark.storage.BlockManagerId -import spark.util.{MetadataCleaner, TimeStampedHashMap} - - -private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) - extends MapOutputTrackerMessage -private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage - -private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { - def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) - sender ! tracker.getSerializedLocations(shuffleId) - - case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) - } -} - -private[spark] class MapOutputTracker extends Logging { - - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - // Set to the MapOutputTrackerActor living on the driver - var trackerActor: ActorRef = _ - - private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - - // Incremented every time a fetch fails so that client nodes know to clear - // their cache of map output locations if this happens. - private var epoch: Long = 0 - private val epochLock = new java.lang.Object - - // Cache a serialized version of the output statuses for each shuffle to send them out faster - var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with MapOutputTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from MapOutputTracker") - } - } - - def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") - } - } - - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - def registerMapOutputs( - shuffleId: Int, - statuses: Array[MapStatus], - changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) - if (changeEpoch) { - incrementEpoch() - } - } - - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - var array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") - } - } - - // Remembers which map output locations are currently being fetched on a worker - private val fetching = new HashSet[Int] - - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - if (fetching.contains(shuffleId)) { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } - - if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort() - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } - } - else{ - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } - } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } - } - } - - private def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } - - def stop() { - communicate(StopMapOutputTracker) - mapStatuses.clear() - metadataCleaner.cancel() - trackerActor = null - } - - // Called on master to increment the epoch number - def incrementEpoch() { - epochLock.synchronized { - epoch += 1 - logDebug("Increasing epoch to " + epoch) - } - } - - // Called on master or workers to get current epoch number - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - // Called on workers to update the epoch number, potentially clearing old outputs - // because of a fetch failure. (Each worker task calls this with the latest epoch - // number on the master at the time it was created.) - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - mapStatuses.clear() - epoch = newEpoch - } - } - } - - def getSerializedLocations(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var epochGotten: Long = -1 - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - return bytes - case None => - statuses = mapStatuses(shuffleId) - epochGotten = epoch - } - } - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that - val bytes = serializeStatuses(statuses) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - } - } - return bytes - } - - // Serialize an array of map output locations into an efficient byte format so that we can send - // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will - // generally be pretty compressible because many map outputs will be on the same hostname. - private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { - val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) - } - objOut.close() - out.toByteArray - } - - // Opposite of serializeStatuses. - def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { - val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject(). - // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present - // comment this out - nulls could be due to missing location ? - asInstanceOf[Array[MapStatus]] // .filter( _ != null ) - } -} - -private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 - - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. - private def convertMapStatuses( - shuffleId: Int, - reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - assert (statuses != null) - statuses.map { - status => - if (status == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } else { - (status.location, decompressSize(status.compressedSizes(reduceId))) - } - } - } - - /** - * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. - * We do this by encoding the log base 1.1 of the size as an integer, which can support - * sizes up to 35 GB with at most 10% error. - */ - def compressSize(size: Long): Byte = { - if (size == 0) { - 0 - } else if (size <= 1L) { - 1 - } else { - math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte - } - } - - /** - * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. - */ - def decompressSize(compressedSize: Byte): Long = { - if (compressedSize == 0) { - 0 - } else { - math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong - } - } -} diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala deleted file mode 100644 index cc1285dd95..0000000000 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ /dev/null @@ -1,703 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.nio.ByteBuffer -import java.util.{Date, HashMap => JHashMap} -import java.text.SimpleDateFormat - -import scala.collection.{mutable, Map} -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.FileOutputFormat -import org.apache.hadoop.mapred.SparkHadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat - -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, SparkHadoopMapReduceUtil} -import org.apache.hadoop.security.UserGroupInformation - -import spark.partial.BoundedDouble -import spark.partial.PartialResult -import spark.rdd._ -import spark.SparkContext._ -import spark.Partitioner._ - -/** - * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable { - - /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: - * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. - * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true, - serializerClass: String = null): RDD[(K, C)] = { - if (getKeyClass().isArray) { - if (mapSideCombine) { - throw new SparkException("Cannot use map-side combining with array keys.") - } - if (partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - } - val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } else if (mapSideCombine) { - val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) - } else { - // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. - assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } - } - - /** - * Simplified version of combineByKey that hash-partitions the output RDD. - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - numPartitions: Int): RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - // Serialize the zero value to a byte array so that we can get a new clone of it on each key - val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) - val zeroArray = new Array[Byte](zeroBuffer.limit) - zeroBuffer.get(zeroArray) - - // When deserializing, use a lazy val to create just one instance of the serializer per task - lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) - - combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, defaultPartitioner(self))(func) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]((v: V) => v, func, func, partitioner) - } - - /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. - */ - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { - - if (getKeyClass().isArray) { - throw new SparkException("reduceByKeyLocally() does not support array keys") - } - - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { - val map = new JHashMap[K, V] - iter.foreach { case (k, v) => - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) - } - Iterator(map) - } - - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - m2.foreach { case (k, v) => - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) - } - m1 - } - - self.mapPartitions(reducePartition).reduce(mergeMaps) - } - - /** Alias for reduceByKeyLocally */ - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { - self.map(_._1).countByValueApprox(timeout, confidence) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Allows controlling the - * partitioning of the resulting key-value pair RDD by passing a Partitioner. - */ - def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { - // groupByKey shouldn't use map side combine because map side combine does not - // reduce the amount of data shuffled and requires all map side data be inserted - // into a hash table, leading to more objects in the old gen. - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) - bufs.asInstanceOf[RDD[(K, Seq[V])]] - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { - groupByKey(new HashPartitioner(numPartitions)) - } - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { - if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - new ShuffledRDD[K, V, (K, V)](self, partitioner) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. - */ - def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs.iterator; w <- ws.iterator) yield (v, w) - } - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (ws.isEmpty) { - vs.iterator.map(v => (v, None)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) - } - } - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (vs.isEmpty) { - ws.iterator.map(w => (None, w)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) - } - } - } - - /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the - * existing partitioner/parallelism level. - */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. - */ - def groupByKey(): RDD[(K, Seq[V])] = { - groupByKey(defaultPartitioner(self)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { - join(other, defaultPartitioner(self, other)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { - join(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * using the existing partitioner/parallelism level. - */ - def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numPartitions` partitions. - */ - def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the existing partitioner/parallelism level. - */ - def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD into the given number of partitions. - */ - def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Return the key-value pairs in this RDD to the master as a Map. - */ - def collectAsMap(): Map[K, V] = { - val data = self.toArray() - val map = new mutable.HashMap[K, V] - map.sizeHint(data.length) - data.foreach { case (k, v) => map.put(k, v) } - map - } - - /** - * Pass each value in the key-value pair RDD through a map function without changing the keys; - * this also retains the original RDD's partitioning. - */ - def mapValues[U](f: V => U): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new MappedValuesRDD(self, cleanF) - } - - /** - * Pass each value in the key-value pair RDD through a flatMap function without changing the - * keys; this also retains the original RDD's partitioning. - */ - def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new FlatMappedValuesRDD(self, cleanF) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) - } - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) - } - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, new HashPartitioner(numPartitions)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, new HashPartitioner(numPartitions)) - } - - /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** Alias for cogroup. */ - def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * Return an RDD with the pairs from `this` whose keys are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = - subtractByKey(other, new HashPartitioner(numPartitions)) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = - new SubtractedRDD[K, V, W](self, other, p) - - /** - * Return the list of values in the RDD for key `key`. This operation is done efficiently if the - * RDD has a known partitioner by only searching the partition that the key maps to. - */ - def lookup(key: K): Seq[V] = { - self.partitioner match { - case Some(p) => - val index = p.getPartition(key) - def process(it: Iterator[(K, V)]): Seq[V] = { - val buf = new ArrayBuffer[V] - for ((k, v) <- it if k == key) { - buf += v - } - buf - } - val res = self.context.runJob(self, process _, Array(index), false) - res(0) - case None => - self.filter(_._1 == key).map(_._2).collect() - } - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress the result with the - * supplied codec. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) - job.setOutputKeyClass(keyClass) - job.setOutputValueClass(valueClass) - val wrappedConf = new SerializableWritable(job.getConfiguration) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = outputFormatClass.newInstance - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - while (iter.hasNext) { - val (k, v) = iter.next - writer.write(k, v) - } - writer.close(hadoopContext) - committer.commitTask(hadoopContext) - return 1 - } - val jobFormat = outputFormatClass.newInstance - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - val count = self.context.runJob(self, writeShard _).sum - jobCommitter.commitJob(jobTaskContext) - jobCommitter.cleanupJob(jobTaskContext) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress with the supplied codec. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - codec: Class[_ <: CompressionCodec]) { - saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, - new JobConf(self.context.hadoopConfiguration), Some(codec)) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration), - codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) - // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug - conf.set("mapred.output.format.class", outputFormatClass.getName) - for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) - } - - /** - * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for - * that storage system. The JobConf should set an OutputFormat and any output paths required - * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop - * MapReduce job. - */ - def saveAsHadoopDataset(conf: JobConf) { - val outputFormatClass = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass - if (outputFormatClass == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - - logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") - - val writer = new SparkHadoopWriter(conf) - writer.preSetup() - - def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt - - writer.setup(context.stageId, context.splitId, attemptNumber) - writer.open() - - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - } - - writer.close() - writer.commit() - } - - self.context.runJob(self, writeToFile _) - writer.commitJob() - writer.cleanup() - } - - /** - * Return an RDD with the keys of each tuple. - */ - def keys: RDD[K] = self.map(_._1) - - /** - * Return an RDD with the values of each tuple. - */ - def values: RDD[V] = self.map(_._2) - - private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure - - private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure -} - - -private[spark] object Manifests { - val seqSeqManifest = classManifest[Seq[Seq[_]]] -} diff --git a/core/src/main/scala/spark/Partition.scala b/core/src/main/scala/spark/Partition.scala deleted file mode 100644 index 2a4edcec98..0000000000 --- a/core/src/main/scala/spark/Partition.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * A partition of an RDD. - */ -trait Partition extends Serializable { - /** - * Get the split's index within its parent RDD - */ - def index: Int - - // A better default implementation of HashCode - override def hashCode(): Int = index -} diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala deleted file mode 100644 index 65da8235d7..0000000000 --- a/core/src/main/scala/spark/Partitioner.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * An object that defines how the elements in a key-value pair RDD are partitioned by key. - * Maps each key to a partition ID, from 0 to `numPartitions - 1`. - */ -abstract class Partitioner extends Serializable { - def numPartitions: Int - def getPartition(key: Any): Int -} - -object Partitioner { - /** - * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. - * - * If any of the RDDs already has a partitioner, choose that one. - * - * Otherwise, we use a default HashPartitioner. For the number of partitions, if - * spark.default.parallelism is set, then we'll use the value from SparkContext - * defaultParallelism, otherwise we'll use the max number of upstream partitions. - * - * Unless spark.default.parallelism is set, He number of partitions will be the - * same as the number of partitions in the largest upstream RDD, as this should - * be least likely to cause out-of-memory errors. - * - * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. - */ - def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { - val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner != None) { - return r.partitioner.get - } - if (System.getProperty("spark.default.parallelism") != null) { - return new HashPartitioner(rdd.context.defaultParallelism) - } else { - return new HashPartitioner(bySize.head.partitions.size) - } - } -} - -/** - * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. - * - * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, - * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will - * produce an unexpected or incorrect result. - */ -class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions - - def getPartition(key: Any): Int = key match { - case null => 0 - case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) - } - - override def equals(other: Any): Boolean = other match { - case h: HashPartitioner => - h.numPartitions == numPartitions - case _ => - false - } -} - -/** - * A [[spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. - * Determines the ranges by sampling the RDD passed in. - */ -class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( - partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], - private val ascending: Boolean = true) - extends Partitioner { - - // An array of upper bounds for the first (partitions - 1) partitions - private val rangeBounds: Array[K] = { - if (partitions == 1) { - Array() - } else { - val rddSize = rdd.count() - val maxSampleSize = partitions * 20.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _) - if (rddSample.length == 0) { - Array() - } else { - val bounds = new Array[K](partitions - 1) - for (i <- 0 until partitions - 1) { - val index = (rddSample.length - 1) * (i + 1) / partitions - bounds(i) = rddSample(index) - } - bounds - } - } - } - - def numPartitions = partitions - - def getPartition(key: Any): Int = { - // TODO: Use a binary search here if number of partitions is large - val k = key.asInstanceOf[K] - var partition = 0 - while (partition < rangeBounds.length && k > rangeBounds(partition)) { - partition += 1 - } - if (ascending) { - partition - } else { - rangeBounds.length - partition - } - } - - override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => - r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending - case _ => - false - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala deleted file mode 100644 index 25a6951732..0000000000 --- a/core/src/main/scala/spark/RDD.scala +++ /dev/null @@ -1,957 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.util.Random - -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.TextOutputFormat - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -import spark.Partitioner._ -import spark.api.java.JavaRDD -import spark.partial.BoundedDouble -import spark.partial.CountEvaluator -import spark.partial.GroupedCountEvaluator -import spark.partial.PartialResult -import spark.rdd.CoalescedRDD -import spark.rdd.CartesianRDD -import spark.rdd.FilteredRDD -import spark.rdd.FlatMappedRDD -import spark.rdd.GlommedRDD -import spark.rdd.MappedRDD -import spark.rdd.MapPartitionsRDD -import spark.rdd.MapPartitionsWithIndexRDD -import spark.rdd.PipedRDD -import spark.rdd.SampledRDD -import spark.rdd.ShuffledRDD -import spark.rdd.UnionRDD -import spark.rdd.ZippedRDD -import spark.rdd.ZippedPartitionsRDD2 -import spark.rdd.ZippedPartitionsRDD3 -import spark.rdd.ZippedPartitionsRDD4 -import spark.storage.StorageLevel -import spark.util.BoundedPriorityQueue - -import SparkContext._ - -/** - * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, - * partitioned collection of elements that can be operated on in parallel. This class contains the - * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, - * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such - * as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on - * RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs - * that can be saved as SequenceFiles. These operations are automatically available on any RDD of - * the right type (e.g. RDD[(Int, Int)] through implicit conversions when you - * `import spark.SparkContext._`. - * - * Internally, each RDD is characterized by five main properties: - * - * - A list of partitions - * - A function for computing each split - * - A list of dependencies on other RDDs - * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) - * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for - * an HDFS file) - * - * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD - * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for - * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. - */ -abstract class RDD[T: ClassManifest]( - @transient private var sc: SparkContext, - @transient private var deps: Seq[Dependency[_]] - ) extends Serializable with Logging { - - /** Construct an RDD with just a one-to-one dependency on one parent */ - def this(@transient oneParent: RDD[_]) = - this(oneParent.context , List(new OneToOneDependency(oneParent))) - - // ======================================================================= - // Methods that should be implemented by subclasses of RDD - // ======================================================================= - - /** Implemented by subclasses to compute a given partition. */ - def compute(split: Partition, context: TaskContext): Iterator[T] - - /** - * Implemented by subclasses to return the set of partitions in this RDD. This method will only - * be called once, so it is safe to implement a time-consuming computation in it. - */ - protected def getPartitions: Array[Partition] - - /** - * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only - * be called once, so it is safe to implement a time-consuming computation in it. - */ - protected def getDependencies: Seq[Dependency[_]] = deps - - /** Optionally overridden by subclasses to specify placement preferences. */ - protected def getPreferredLocations(split: Partition): Seq[String] = Nil - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None - - // ======================================================================= - // Methods and fields available on all RDDs - // ======================================================================= - - /** The SparkContext that created this RDD. */ - def sparkContext: SparkContext = sc - - /** A unique ID for this RDD (within its SparkContext). */ - val id: Int = sc.newRddId() - - /** A friendly name for this RDD */ - var name: String = null - - /** Assign a name to this RDD */ - def setName(_name: String) = { - name = _name - this - } - - /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass - - /** Reset generator*/ - def setGenerator(_generator: String) = { - generator = _generator - } - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. - */ - def persist(newLevel: StorageLevel): RDD[T] = { - // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { - throw new UnsupportedOperationException( - "Cannot change storage level of an RDD after it was already assigned a level") - } - storageLevel = newLevel - // Register the RDD with the SparkContext - sc.persistentRdds(id) = this - this - } - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): RDD[T] = persist() - - /** - * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. - * - * @param blocking Whether to block until all blocks are deleted. - * @return This RDD. - */ - def unpersist(blocking: Boolean = true): RDD[T] = { - logInfo("Removing RDD " + id + " from persistence list") - sc.env.blockManager.master.removeRdd(id, blocking) - sc.persistentRdds.remove(id) - storageLevel = StorageLevel.NONE - this - } - - /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel - - // Our dependencies and partitions will be gotten by calling subclass's methods below, and will - // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null - - /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) - - /** - * Get the list of dependencies of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def dependencies: Seq[Dependency[_]] = { - checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { - if (dependencies_ == null) { - dependencies_ = getDependencies - } - dependencies_ - } - } - - /** - * Get the array of partitions of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def partitions: Array[Partition] = { - checkpointRDD.map(_.partitions).getOrElse { - if (partitions_ == null) { - partitions_ = getPartitions - } - partitions_ - } - } - - /** - * Get the preferred locations of a partition (as hostnames), taking into account whether the - * RDD is checkpointed. - */ - final def preferredLocations(split: Partition): Seq[String] = { - checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { - getPreferredLocations(split) - } - } - - /** - * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. - * This should ''not'' be called by users directly, but is available for implementors of custom - * subclasses of RDD. - */ - final def iterator(split: Partition, context: TaskContext): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) - } else { - computeOrReadCheckpoint(split, context) - } - } - - /** - * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. - */ - private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - firstParent[T].iterator(split, context) - } else { - compute(split, context) - } - } - - // Transformations (return a new RDD) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = - new FlatMappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): RDD[T] = - map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) - - def distinct(): RDD[T] = distinct(partitions.size) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = { - if (shuffle) { - // include a shuffle step so that our upstream tasks are still distributed - new CoalescedRDD( - new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), - new HashPartitioner(numPartitions)), - numPartitions).keys - } else { - new CoalescedRDD(this, numPartitions) - } - } - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = - new SampledRDD(this, withReplacement, fraction, seed) - - def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { - var fraction = 0.0 - var total = 0 - val multiplier = 3.0 - val initialCount = this.count() - var maxSelected = 0 - - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } - - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 - } else { - maxSelected = initialCount.toInt - } - - if (num > initialCount && !withReplacement) { - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = multiplier * (num + 1) / initialCount - total = num - } - - val rand = new Random(seed) - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size - while (samples.length < total) { - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - } - - Utils.randomizeInPlace(samples, rand).take(total) - } - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def ++(other: RDD[T]): RDD[T] = this.union(other) - - /** - * Return an RDD created by coalescing all elements within each partition into an array. - */ - def glom(): RDD[Array[T]] = new GlommedRDD(this) - - /** - * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of - * elements (a, b) where a is in `this` and b is in `other`. - */ - def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) - - /** - * Return an RDD of grouped items. - */ - def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = - groupBy[K](f, defaultPartitioner(this)) - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = - groupBy(f, new HashPartitioner(numPartitions)) - - /** - * Return an RDD of grouped items. - */ - def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { - val cleanF = sc.clean(f) - this.map(t => (cleanF(t), t)).groupByKey(p) - } - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: String): RDD[String] = new PipedRDD(this, command) - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: String, env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) - - - /** - * Return an RDD created by piping elements to a forked external process. - * The print behavior can be customized by providing two functions. - * - * @param command command to run in forked process. - * @param env environment variables to set. - * @param printPipeContext Before piping elements, this function is called as an oppotunity - * to pipe context data. Print line function (like out.println) will be - * passed as printPipeContext's parameter. - * @param printRDDElement Use this function to customize how to pipe elements. This function - * will be called with each RDD element as the 1st parameter, and the - * print line function (like out.println()) as the 2nd parameter. - * An example of pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} - * @return the result RDD - */ - def pipe( - command: Seq[String], - env: Map[String, String] = Map(), - printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = - new PipedRDD(this, command, env, - if (printPipeContext ne null) sc.clean(printPipeContext) else null, - if (printRDDElement ne null) sc.clean(printRDDElement) else null) - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - def mapPartitionsWithIndex[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) - - /** - * Maps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) - } - - /** - * FlatMaps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) - } - - /** - * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, - * second element in each RDD, etc. Assumes that the two RDDs have the *same number of - * partitions* and the *same number of elements in each partition* (e.g. one was made through - * a map on the other). - */ - def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) - - /** - * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by - * applying a function to the zipped partitions. Assumes that all the RDDs have the - * *same number of partitions*, but does *not* require them to have the same number - * of elements in each partition. - */ - def zipPartitions[B: ClassManifest, V: ClassManifest] - (rdd2: RDD[B]) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) - - def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C]) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) - - def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) - - - // Actions (launch a job to return a value to the user program) - - /** - * Applies a function f to all elements of this RDD. - */ - def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) - } - - /** - * Applies a function f to each partition of this RDD. - */ - def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) - } - - /** - * Return an array that contains all of the elements in this RDD. - */ - def collect(): Array[T] = { - val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - Array.concat(results: _*) - } - - /** - * Return an array that contains all of the elements in this RDD. - */ - def toArray(): Array[T] = collect() - - /** - * Return an RDD that contains all matching values by applying `f`. - */ - def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { - filter(f.isDefinedAt).map(f) - } - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], numPartitions: Int): RDD[T] = - subtract(other, new HashPartitioner(numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], p: Partitioner): RDD[T] = { - if (partitioner == Some(p)) { - // Our partitioner knows how to handle T (which, since we have a partitioner, is - // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples - val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) - } - // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies - // anyway, and when calling .keys, will not have a partitioner set, even though - // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be - // partitioned by the right/real keys (e.g. p). - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys - } else { - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys - } - } - - /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. - */ - def reduce(f: (T, T) => T): T = { - val cleanF = sc.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - var jobResult: Option[T] = None - val mergeResult = (index: Int, taskResult: Option[T]) => { - if (taskResult != None) { - jobResult = jobResult match { - case Some(value) => Some(f(value, taskResult.get)) - case None => taskResult - } - } - } - sc.runJob(this, reducePartition, mergeResult) - // Get the final result out of our Option, or throw an exception if the RDD was empty - jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. - */ - def fold(zeroValue: T)(op: (T, T) => T): T = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanOp = sc.clean(op) - val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) - val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) - sc.runJob(this, foldPartition, mergeResult) - jobResult - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using - * given combine functions and a neutral "zero value". This function can return a different result - * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U - * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are - * allowed to modify and return their first argument instead of creating a new U to avoid memory - * allocation. - */ - def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanSeqOp = sc.clean(seqOp) - val cleanCombOp = sc.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) - sc.runJob(this, aggregatePartition, mergeResult) - jobResult - } - - /** - * Return the number of elements in the RDD. - */ - def count(): Long = { - sc.runJob(this, (iter: Iterator[T]) => { - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - }).sum - } - - /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result - * within a timeout, even if not all tasks have finished. - */ - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - } - val evaluator = new CountEvaluator(partitions.size, confidence) - sc.runApproximateJob(this, countElements, evaluator, timeout) - } - - /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. - */ - def countByValue(): Map[T, Long] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValue() does not support arrays") - } - // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - Iterator(map) - } - def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { - val iter = m2.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) - } - return m1 - } - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map - } - - /** - * (Experimental) Approximate version of countByValue(). - */ - def countByValueApprox( - timeout: Long, - confidence: Double = 0.95 - ): PartialResult[Map[T, BoundedDouble]] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValueApprox() does not support arrays") - } - val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - map - } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) - sc.runApproximateJob(this, countPartition, evaluator, timeout) - } - - /** - * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so - * it will be slow if a lot of partitions are required. In that case, use collect() to get the - * whole RDD instead. - */ - def take(num: Int): Array[T] = { - if (num == 0) { - return new Array[T](0) - } - val buf = new ArrayBuffer[T] - var p = 0 - while (buf.size < num && p < partitions.size) { - val left = num - buf.size - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) - buf ++= res(0) - if (buf.size == num) - return buf.toArray - p += 1 - } - return buf.toArray - } - - /** - * Return the first element in this RDD. - */ - def first(): T = take(1) match { - case Array(t) => t - case _ => throw new UnsupportedOperationException("empty collection") - } - - /** - * Returns the top K elements from this RDD as defined by - * the specified implicit Ordering[T]. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { - mapPartitions { items => - val queue = new BoundedPriorityQueue[T](num) - queue ++= items - Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord.reverse) - } - - /** - * Returns the first K elements from this RDD as defined by - * the specified implicit Ordering[T] and maintains the - * ordering. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) - - /** - * Save this RDD as a text file, using string representations of elements. - */ - def saveAsTextFile(path: String) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) - } - - /** - * Save this RDD as a compressed text file, using string representations of elements. - */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) - } - - /** - * Save this RDD as a SequenceFile of serialized objects. - */ - def saveAsObjectFile(path: String) { - this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) - .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) - .saveAsSequenceFile(path) - } - - /** - * Creates tuples of the elements in this RDD by applying `f`. - */ - def keyBy[K](f: T => K): RDD[(K, T)] = { - map(x => (f(x), x)) - } - - /** A private method for tests, to look at the contents of each partition */ - private[spark] def collectPartitions(): Array[Array[T]] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - } - - /** - * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent - * RDDs will be removed. This function must be called before any job has been - * executed on this RDD. It is strongly recommended that this RDD is persisted in - * memory, otherwise saving it on a file will require recomputation. - */ - def checkpoint() { - if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") - } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() - } - } - - /** - * Return whether this RDD has been checkpointed or not - */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } - - /** - * Gets the name of the file to which this RDD was checkpointed - */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } - - // ======================================================================= - // Other internal methods and fields - // ======================================================================= - - private var storageLevel: StorageLevel = StorageLevel.NONE - - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite - - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc - - // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false - - /** - * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler - * after a job using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - if (!doCheckpointCalled) { - doCheckpointCalled = true - if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() - } else { - dependencies.foreach(_.rdd.doCheckpoint()) - } - } - } - - /** - * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) - * created from the checkpoint file, and forget its old dependencies and partitions. - */ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { - clearDependencies() - partitions_ = null - deps = null // Forget the constructor argument for dependencies too - } - - /** - * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[spark.rdd.UnionRDD]] for an example. - */ - protected def clearDependencies() { - dependencies_ = null - } - - /** A description of this RDD and its recursive dependencies for debugging. */ - def toDebugString: String = { - def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ - rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) - } - debugString(this).mkString("\n") - } - - override def toString: String = "%s%s[%d] at %s".format( - Option(name).map(_ + " ").getOrElse(""), - getClass.getSimpleName, - id, - origin) - - def toJavaRDD() : JavaRDD[T] = { - new JavaRDD(this)(elementClassManifest) - } - -} diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala deleted file mode 100644 index b615f820eb..0000000000 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration -import rdd.{CheckpointRDD, CoalescedRDD} -import scheduler.{ResultTask, ShuffleMapTask} - -/** - * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ -private[spark] object CheckpointState extends Enumeration { - type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value -} - -/** - * This class contains all the information related to RDD checkpointing. Each instance of this class - * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, - * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations - * of the checkpointed RDD. - */ -private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) - extends Logging with Serializable { - - import CheckpointState._ - - // The checkpoint state of the associated RDD. - var cpState = Initialized - - // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None - - // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None - - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } - - // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } - } - - // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } - } - - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return - RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { - cpState = CheckpointingInProgress - } else { - return - } - } - - // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } - - // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - - // Change the dependencies and partitions of the RDD - RDDCheckpointData.synchronized { - cpFile = Some(path.toString) - cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions - cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) - } - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } - } - - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } - } - - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } - } -} - -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala deleted file mode 100644 index 9f30b7f22f..0000000000 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.EOFException -import java.net.URL -import java.io.ObjectInputStream -import java.util.concurrent.atomic.AtomicLong -import java.util.HashSet -import java.util.Random -import java.util.Date - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Map -import scala.collection.mutable.HashMap - -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.TextOutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.Writable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.Text - -import spark.SparkContext._ - -/** - * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. - * - * Users should import `spark.SparkContext._` at the top of their program to use these functions. - */ -class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( - self: RDD[(K, V)]) - extends Logging - with Serializable { - - private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { - val c = { - if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { - classManifest[T].erasure - } else { - // We get the type of the Writable class by looking at the apply method which converts - // from T to Writable. Since we have two apply methods we filter out the one which - // is not of the form "java.lang.Object apply(java.lang.Object)" - implicitly[T => Writable].getClass.getDeclaredMethods().filter( - m => m.getReturnType().toString != "class java.lang.Object" && - m.getName() == "apply")(0).getReturnType - - } - // TODO: use something like WritableConverter to avoid reflection - } - c.asInstanceOf[Class[_ <: Writable]] - } - - /** - * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key - * and value types. If the key or value are Writable, then we use their classes directly; - * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, - * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported - * file system. - */ - def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { - def anyToWritable[U <% Writable](u: U): Writable = u - - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass) - - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) - val format = classOf[SequenceFileOutputFormat[Writable, Writable]] - val jobConf = new JobConf(self.context.hadoopConfiguration) - if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) - } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } - } -} diff --git a/core/src/main/scala/spark/SerializableWritable.scala b/core/src/main/scala/spark/SerializableWritable.scala deleted file mode 100644 index 936d8e6241..0000000000 --- a/core/src/main/scala/spark/SerializableWritable.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ - -import org.apache.hadoop.io.ObjectWritable -import org.apache.hadoop.io.Writable -import org.apache.hadoop.conf.Configuration - -class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString - - private def writeObject(out: ObjectOutputStream) { - out.defaultWriteObject() - new ObjectWritable(t).write(out) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - val ow = new ObjectWritable() - ow.setConf(new Configuration()) - ow.readFields(in) - t = ow.get().asInstanceOf[T] - } -} diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala deleted file mode 100644 index a6839cf7a4..0000000000 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.executor.TaskMetrics -import spark.serializer.Serializer - - -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, - serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala deleted file mode 100644 index 6cc57566d7..0000000000 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.lang.reflect.Field -import java.lang.reflect.Modifier -import java.lang.reflect.{Array => JArray} -import java.util.IdentityHashMap -import java.util.concurrent.ConcurrentHashMap -import java.util.Random - -import javax.management.MBeanServer -import java.lang.management.ManagementFactory - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.ints.IntOpenHashSet - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -private[spark] object SizeEstimator extends Logging { - - // Sizes of primitive types - private val BYTE_SIZE = 1 - private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 - - // Alignment boundary for objects - // TODO: Is this arch dependent ? - private val ALIGN_SIZE = 8 - - // A cache of ClassInfo objects for each class - private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] - - // Object and pointer sizes are arch dependent - private var is64bit = false - - // Size of an object reference - // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops - private var isCompressedOops = false - private var pointerSize = 4 - - // Minimum size of a java.lang.Object - private var objectSize = 8 - - initialize() - - // Sets object size, pointer size based on architecture and CompressedOops settings - // from the JVM. - private def initialize() { - is64bit = System.getProperty("os.arch").contains("64") - isCompressedOops = getIsCompressedOops - - objectSize = if (!is64bit) 8 else { - if(!isCompressedOops) { - 16 - } else { - 12 - } - } - pointerSize = if (is64bit && !isCompressedOops) 8 else 4 - classInfos.clear() - classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) - } - - private def getIsCompressedOops : Boolean = { - if (System.getProperty("spark.test.useCompressedOops") != null) { - return System.getProperty("spark.test.useCompressedOops").toBoolean - } - - try { - val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" - val server = ManagementFactory.getPlatformMBeanServer() - - // NOTE: This should throw an exception in non-Sun JVMs - val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") - val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", - Class.forName("java.lang.String")) - - val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, hotSpotMBeanClass) - // TODO: We could use reflection on the VMOption returned ? - return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") - } catch { - case e: Exception => { - // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB - val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) - val guessInWords = if (guess) "yes" else "not" - logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) - return guess - } - } - } - - /** - * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an - * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects - * to visit. - */ - private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) { - val stack = new ArrayBuffer[AnyRef] - var size = 0L - - def enqueue(obj: AnyRef) { - if (obj != null && !visited.containsKey(obj)) { - visited.put(obj, null) - stack += obj - } - } - - def isFinished(): Boolean = stack.isEmpty - - def dequeue(): AnyRef = { - val elem = stack.last - stack.trimEnd(1) - return elem - } - } - - /** - * Cached information about each class. We remember two things: the "shell size" of the class - * (size of all non-static fields plus the java.lang.Object size), and any fields that are - * pointers to objects. - */ - private class ClassInfo( - val shellSize: Long, - val pointerFields: List[Field]) {} - - def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) - - private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { - val state = new SearchState(visited) - state.enqueue(obj) - while (!state.isFinished) { - visitSingleObject(state.dequeue(), state) - } - return state.size - } - - private def visitSingleObject(obj: AnyRef, state: SearchState) { - val cls = obj.getClass - if (cls.isArray) { - visitArray(obj, cls, state) - } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { - // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses - // the size estimator since it references the whole REPL. Do nothing in this case. In - // general all ClassLoaders and Classes will be shared between objects anyway. - } else { - val classInfo = getClassInfo(cls) - state.size += classInfo.shellSize - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } - } - } - - // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 - private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING - - private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { - val length = JArray.getLength(array) - val elementClass = cls.getComponentType - - // Arrays have object header and length field which is an integer - var arrSize: Long = alignSize(objectSize + INT_SIZE) - - if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) - state.size += arrSize - } else { - arrSize += alignSize(length * pointerSize) - state.size += arrSize - - if (length <= ARRAY_SIZE_FOR_SAMPLING) { - for (i <- 0 until length) { - state.enqueue(JArray.get(array, i)) - } - } else { - // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 - val rand = new Random(42) - val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) - for (i <- 0 until ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = JArray.get(array, index) - size += SizeEstimator.estimate(elem, state.visited) - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong - } - } - } - - private def primitiveSize(cls: Class[_]): Long = { - if (cls == classOf[Byte]) - BYTE_SIZE - else if (cls == classOf[Boolean]) - BOOLEAN_SIZE - else if (cls == classOf[Char]) - CHAR_SIZE - else if (cls == classOf[Short]) - SHORT_SIZE - else if (cls == classOf[Int]) - INT_SIZE - else if (cls == classOf[Long]) - LONG_SIZE - else if (cls == classOf[Float]) - FLOAT_SIZE - else if (cls == classOf[Double]) - DOUBLE_SIZE - else throw new IllegalArgumentException( - "Non-primitive class " + cls + " passed to primitiveSize()") - } - - /** - * Get or compute the ClassInfo for a given class. - */ - private def getClassInfo(cls: Class[_]): ClassInfo = { - // Check whether we've already cached a ClassInfo for this class - val info = classInfos.get(cls) - if (info != null) { - return info - } - - val parent = getClassInfo(cls.getSuperclass) - var shellSize = parent.shellSize - var pointerFields = parent.pointerFields - - for (field <- cls.getDeclaredFields) { - if (!Modifier.isStatic(field.getModifiers)) { - val fieldClass = field.getType - if (fieldClass.isPrimitive) { - shellSize += primitiveSize(fieldClass) - } else { - field.setAccessible(true) // Enable future get()'s on this field - shellSize += pointerSize - pointerFields = field :: pointerFields - } - } - } - - shellSize = alignSize(shellSize) - - // Create and cache a new ClassInfo - val newInfo = new ClassInfo(shellSize, pointerFields) - classInfos.put(cls, newInfo) - return newInfo - } - - private def alignSize(size: Long): Long = { - val rem = size % ALIGN_SIZE - return if (rem == 0) size else (size + ALIGN_SIZE - rem) - } -} diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala deleted file mode 100644 index 7ce9505b9c..0000000000 --- a/core/src/main/scala/spark/SparkContext.scala +++ /dev/null @@ -1,995 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.net.URI -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.Map -import scala.collection.generic.Growable -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.ArrayWritable -import org.apache.hadoop.io.BooleanWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.DoubleWritable -import org.apache.hadoop.io.FloatWritable -import org.apache.hadoop.io.IntWritable -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileInputFormat -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.SequenceFileInputFormat -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} - -import org.apache.mesos.MesosNativeLibrary - -import spark.deploy.LocalSparkCluster -import spark.partial.{ApproximateEvaluator, PartialResult} -import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD, - OrderedRDDFunctions} -import spark.scheduler._ -import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, - ClusterScheduler, Schedulable, SchedulingMode} -import spark.scheduler.local.LocalScheduler -import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import spark.storage.{StorageStatus, StorageUtils, RDDInfo, BlockManagerSource} -import spark.ui.SparkUI -import spark.util.{MetadataCleaner, TimeStampedHashMap} -import scala.Some -import spark.scheduler.StageInfo -import spark.storage.RDDInfo -import spark.storage.StorageStatus - -/** - * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark - * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. - * - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI. - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - * @param environment Environment variables to set on worker nodes. - */ -class SparkContext( - val master: String, - val appName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. - // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) - extends Logging { - - // Ensure logging is initialized before we spawn any threads - initLogging() - - // Set Spark driver host and port system properties - if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localHostName()) - } - if (System.getProperty("spark.driver.port") == null) { - System.setProperty("spark.driver.port", "0") - } - - val isLocal = (master == "local" || master.startsWith("local[")) - - // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createFromSystemProperties( - "", - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port").toInt, - true, - isLocal) - SparkEnv.set(env) - - // Used to store a URL for each static file/jar together with the file's local timestamp - private[spark] val addedFiles = HashMap[String, Long]() - private[spark] val addedJars = HashMap[String, Long]() - - // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) - - // Initalize the Spark UI - private[spark] val ui = new SparkUI(this) - ui.bind() - - val startTime = System.currentTimeMillis() - - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach { addJar(_) } - } - - // Environment variables to pass to our executors - private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner - for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { - val value = System.getenv(key) - if (value != null) { - executorEnvs(key) = value - } - } - // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" - if (environment != null) { - executorEnvs ++= environment - } - - // Create and start the scheduler - private var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """(spark://.*)""".r - //Regular expression for connection to Mesos cluster - val MESOS_REGEX = """(mesos://.*)""".r - - master match { - case "local" => - new LocalScheduler(1, 0, this) - - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val sparkUrl = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case _ => - if (MESOS_REGEX.findFirstIn(master).isEmpty) { - logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) - } - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) - } else { - new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) - } - scheduler.initialize(backend) - scheduler - } - } - taskScheduler.start() - - @volatile private var dagScheduler = new DAGScheduler(taskScheduler) - dagScheduler.start() - - ui.start() - - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ - val hadoopConfiguration = { - val env = SparkEnv.get - val conf = env.hadoop.newConfiguration() - // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - } - // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) - } - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) - conf - } - - private[spark] var checkpointDir: Option[String] = None - - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new DynamicVariable[Properties](null) - - def initLocalProperties() { - localProperties.value = new Properties() - } - - def setLocalProperty(key: String, value: String) { - if (localProperties.value == null) { - localProperties.value = new Properties() - } - if (value == null) { - localProperties.value.remove(key) - } else { - localProperties.value.setProperty(key, value) - } - } - - /** Set a human readable description of the current job. */ - def setJobDescription(value: String) { - setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) - } - - // Post init - taskScheduler.postStartHook() - - val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - - // Methods for creating RDDs - - /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) - } - - /** Distribute a local Scala collection to form an RDD. */ - def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - parallelize(seq, numSlices) - } - - /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences (hostnames of Spark nodes) for each object. - * Create a new partition for each collection item. */ - def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { - val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) - } - - /** - * Read a text file from HDFS, a local file system (available on all nodes), or any - * Hadoop-supported file system URI, and return it as an RDD of Strings. - */ - def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = { - hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits) - .map(pair => pair._2.toString) - } - - /** - * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). - */ - def hadoopRDD[K, V]( - conf: JobConf, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int = defaultMinSplits - ): RDD[(K, V)] = { - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V]( - path: String, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int = defaultMinSplits - ) : RDD[(K, V)] = { - val conf = new JobConf(hadoopConfiguration) - FileInputFormat.setInputPaths(conf, path) - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, - * values and the InputFormat so that users don't need to pass them directly. Instead, callers - * can just write, for example, - * {{{ - * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) - * }}} - */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) - : RDD[(K, V)] = { - hadoopFile(path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - minSplits) - } - - /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, - * values and the InputFormat so that users don't need to pass them directly. Instead, callers - * can just write, for example, - * {{{ - * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) - * }}} - */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = - hadoopFile[K, V, F](path, defaultMinSplits) - - /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = { - newAPIHadoopFile( - path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]]) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( - path: String, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { - val job = new NewHadoopJob(conf) - NewFileInputFormat.addInputPath(job, new Path(path)) - val updatedConf = job.getConfiguration - new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration = hadoopConfiguration, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) - } - - /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): RDD[(K, V)] = { - val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] - hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = - sequenceFile(path, keyClass, valueClass, defaultMinSplits) - - /** - * Version of sequenceFile() for types implicitly convertible to Writables through a - * WritableConverter. For example, to access a SequenceFile where the keys are Text and the - * values are IntWritable, you could simply write - * {{{ - * sparkContext.sequenceFile[String, Int](path, ...) - * }}} - * - * WritableConverters are provided in a somewhat strange way (by an implicit function) to support - * both subclasses of Writable and types for which we define a converter (e.g. Int to - * IntWritable). The most natural thing would've been to have implicit objects for the - * converters, but then we couldn't have an object for every subclass of Writable (you can't - * have a parameterized singleton object). We use functions instead to create a new converter - * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to - * allow it to figure out the Writable class to use in the subclass case. - */ - def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassManifest[K], vm: ClassManifest[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) - : RDD[(K, V)] = { - val kc = kcf() - val vc = vcf() - val format = classOf[SequenceFileInputFormat[Writable, Writable]] - val writables = hadoopFile(path, format, - kc.writableClass(km).asInstanceOf[Class[Writable]], - vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits) - writables.map{case (k,v) => (kc.convert(k), vc.convert(v))} - } - - /** - * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and - * BytesWritable values that contain a serialized partition. This is still an experimental storage - * format and may not be supported exactly as is in future Spark releases. It will also be pretty - * slow if you use the default serializer (Java serialization), though the nice thing about it is - * that there's very little effort required to save arbitrary objects. - */ - def objectFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits - ): RDD[T] = { - sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits) - .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) - } - - - protected[spark] def checkpointFile[T: ClassManifest]( - path: String - ): RDD[T] = { - new CheckpointRDD[T](this, path) - } - - /** Build the union of a list of RDDs. */ - def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) - - /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) - - // Methods for creating shared variables - - /** - * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the driver can access the accumulator's `value`. - */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) - - /** - * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the driver can access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator - */ - def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = - new Accumulable(initialValue, param) - - /** - * Create an accumulator from a "mutable collection" type. - * - * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by - * standard mutable collections. So you can use this with mutable Map, Set, etc. - */ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { - val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) - } - - /** - * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for - * reading it in distributed functions. The variable will be sent to each cluster only once. - */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) - - /** - * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(path)` to find its download location. - */ - def addFile(path: String) { - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case _ => path - } - addedFiles(key) = System.currentTimeMillis - - // Fetch the file locally in case a job is executed locally. - // Jobs that run through LocalScheduler will already fetch the required dependencies, - // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) - - logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) - } - - def addSparkListener(listener: SparkListener) { - dagScheduler.addSparkListener(listener) - } - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. - */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - def getRDDStorageInfo: Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. - */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - - def getStageInfo: Map[Stage,StageInfo] = { - dagScheduler.stageToInfos - } - - /** - * Return information about blocks stored in all of the slaves - */ - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus - } - - /** - * Return pools for fair scheduler - * TODO(xiajunluan): We should take nested pools into account - */ - def getAllPools: ArrayBuffer[Schedulable] = { - taskScheduler.rootPool.schedulableQueue - } - - /** - * Return the pool associated with the given name, if one exists - */ - def getPoolForName(pool: String): Option[Schedulable] = { - taskScheduler.rootPool.schedulableNameToSchedulable.get(pool) - } - - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } - - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - def clearFiles() { - addedFiles.clear() - } - - /** - * Gets the locality information associated with the partition in a particular rdd - * @param rdd of interest - * @param partition to be looked up for locality - * @return list of preferred locations for the partition - */ - private [spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { - dagScheduler.getPreferredLocs(rdd, partition) - } - - /** - * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. - */ - def addJar(path: String) { - if (null == path) { - logWarning("null specified as parameter to addJar", - new SparkException("null specified as parameter to addJar")) - } else { - val env = SparkEnv.get - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => - if (env.hadoop.isYarnMode()) { - logWarning("local jar specified as parameter to addJar under Yarn mode") - return - } - env.httpFileServer.addJar(new File(uri.getPath)) - case _ => path - } - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) - } - } - - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - def clearJars() { - addedJars.clear() - } - - /** Shut down the SparkContext. */ - def stop() { - ui.stop() - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - metadataCleaner.cancel() - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") - } - } - - - /** - * Get Spark's home location from either a value set through the constructor, - * or the spark.home Java property, or the SPARK_HOME environment variable - * (in that order of preference). If neither of these is set, return None. - */ - private[spark] def getSparkHome(): Option[String] = { - if (sparkHome != null) { - Some(sparkHome) - } else if (System.getProperty("spark.home") != null) { - Some(System.getProperty("spark.home")) - } else if (System.getenv("SPARK_HOME") != null) { - Some(System.getenv("SPARK_HOME")) - } else { - None - } - } - - /** - * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { - val callSite = Utils.formatSparkCallSite - logInfo("Starting job: " + callSite) - val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") - rdd.doCheckpoint() - result - } - - /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results - } - - /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: Iterator[T] => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal) - } - - /** - * Run a job on all partitions in an RDD and return the results in an array. - */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) - } - - /** - * Run a job on all partitions in an RDD and return the results in an array. - */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) - } - - /** - * Run a job on all partitions in an RDD and pass the results to a handler function. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - processPartition: (TaskContext, Iterator[T]) => U, - resultHandler: (Int, U) => Unit) - { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) - } - - /** - * Run a job on all partitions in an RDD and pass the results to a handler function. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - processPartition: Iterator[T] => U, - resultHandler: (Int, U) => Unit) - { - val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) - } - - /** - * Run a job that can return approximate results. - */ - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long): PartialResult[R] = { - val callSite = Utils.formatSparkCallSite - logInfo("Starting job: " + callSite) - val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") - result - } - - /** - * Clean a closure to make it ready to serialized and send to tasks - * (removes unreferenced variables in $outer's, updates REPL variables) - */ - private[spark] def clean[F <: AnyRef](f: F): F = { - ClosureCleaner.clean(f) - return f - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. - */ - def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val env = SparkEnv.get - val path = new Path(dir) - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - if (!useExisting) { - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - } - } - checkpointDir = Some(dir) - } - - /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism - - /** Default min number of partitions for Hadoop RDDs when not given by user */ - def defaultMinSplits: Int = math.min(defaultParallelism, 2) - - private val nextShuffleId = new AtomicInteger(0) - - private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement() - - private val nextRddId = new AtomicInteger(0) - - /** Register a new RDD, returning its RDD ID */ - private[spark] def newRddId(): Int = nextRddId.getAndIncrement() - - /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ - private[spark] def cleanup(cleanupTime: Long) { - persistentRdds.clearOldValues(cleanupTime) - } -} - -/** - * The SparkContext object contains a number of implicit conversions and parameters for use with - * various Spark features. - */ -object SparkContext { - val SPARK_JOB_DESCRIPTION = "spark.job.description" - - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 - } - - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 - } - - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0l - } - - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f - } - - // TODO: Add AccumulatorParams for other types, e.g. lists and strings - - implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = - new PairRDDFunctions(rdd) - - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( - rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) - - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( - rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) - - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) - - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) - - // Implicit conversions to common Writable types, for saveAsSequenceFile - - implicit def intToIntWritable(i: Int) = new IntWritable(i) - - implicit def longToLongWritable(l: Long) = new LongWritable(l) - - implicit def floatToFloatWritable(f: Float) = new FloatWritable(f) - - implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d) - - implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b) - - implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob) - - implicit def stringToText(s: String) = new Text(s) - - private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = { - def anyToWritable[U <% Writable](u: U): Writable = u - - new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]], - arr.map(x => anyToWritable(x)).toArray) - } - - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = { - val wClass = classManifest[W].erasure.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } - - implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get) - - implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get) - - implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get) - - implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get) - - implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get) - - implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) - - implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) - - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T]) - - /** - * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext - */ - def jarOfClass(cls: Class[_]): Seq[String] = { - val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") - if (uri != null) { - val uriStr = uri.toString - if (uriStr.startsWith("jar:file:")) { - // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar - List(uriStr.substring("jar:file:".length, uriStr.indexOf('!'))) - } else { - Nil - } - } else { - Nil - } - } - - /** Find the JAR that contains the class of a particular object */ - def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) - - /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ - private[spark] val executorMemoryRequested = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } -} - -/** - * A class encapsulating how to convert some type T to Writable. It stores both the Writable class - * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The getter for the writable class takes a ClassManifest[T] in case this is a generic object - * that doesn't know the type of T when it is created. This sounds strange but is necessary to - * support converting subclasses of Writable to themselves (writableWritableConverter). - */ -private[spark] class WritableConverter[T]( - val writableClass: ClassManifest[T] => Class[_ <: Writable], - val convert: Writable => T) - extends Serializable - diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala deleted file mode 100644 index 1f66e9cc7f..0000000000 --- a/core/src/main/scala/spark/SparkEnv.scala +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import collection.mutable -import serializer.Serializer - -import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} -import akka.remote.RemoteActorRefProvider - -import spark.broadcast.BroadcastManager -import spark.metrics.MetricsSystem -import spark.deploy.SparkHadoopUtil -import spark.storage.BlockManager -import spark.storage.BlockManagerMaster -import spark.network.ConnectionManager -import spark.serializer.{Serializer, SerializerManager} -import spark.util.AkkaUtils -import spark.api.python.PythonWorkerFactory - - -/** - * Holds all the runtime environment objects for a running Spark instance (either master or worker), - * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these - * objects needs to have the right SparkEnv set. You can get the current environment with - * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. - */ -class SparkEnv ( - val executorId: String, - val actorSystem: ActorSystem, - val serializerManager: SerializerManager, - val serializer: Serializer, - val closureSerializer: Serializer, - val cacheManager: CacheManager, - val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, - val broadcastManager: BroadcastManager, - val blockManager: BlockManager, - val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer, - val sparkFilesDir: String, - val metricsSystem: MetricsSystem) { - - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - - val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if(yarnMode) { - try { - Class.forName("spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] - } catch { - case th: Throwable => throw new SparkException("Unable to load YARN support", th) - } - } else { - new SparkHadoopUtil - } - } - - def stop() { - pythonWorkers.foreach { case(key, worker) => worker.stop() } - httpFileServer.stop() - mapOutputTracker.stop() - shuffleFetcher.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - actorSystem.shutdown() - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - actorSystem.awaitTermination() - } - - def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() - } - } -} - -object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] - @volatile private var lastSetSparkEnv : SparkEnv = _ - - def set(e: SparkEnv) { - lastSetSparkEnv = e - env.set(e) - } - - /** - * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv - * previously set in any thread. - */ - def get: SparkEnv = { - Option(env.get()).getOrElse(lastSetSparkEnv) - } - - /** - * Returns the ThreadLocal SparkEnv. - */ - def getThreadLocal : SparkEnv = { - env.get() - } - - def createFromSystemProperties( - executorId: String, - hostname: String, - port: Int, - isDriver: Boolean, - isLocal: Boolean): SparkEnv = { - - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) - - // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.driver.port to it. - if (isDriver && port == 0) { - System.setProperty("spark.driver.port", boundPort.toString) - } - - // set only if unset until now. - if (System.getProperty("spark.hostPort", null) == null) { - if (!isDriver){ - // unexpected - Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") - } - Utils.checkHost(hostname) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - } - - val classLoader = Thread.currentThread.getContextClassLoader - - // Create an instance of the class named by the given Java system property, or by - // defaultClassName if the property is not set, and return it as a T - def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = System.getProperty(propertyName, defaultClassName) - Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] - } - - val serializerManager = new SerializerManager - - val serializer = serializerManager.setDefault( - System.getProperty("spark.serializer", "spark.JavaSerializer")) - - val closureSerializer = serializerManager.get( - System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) - - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { - if (isDriver) { - logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) - } else { - val driverHost: String = System.getProperty("spark.driver.host", "localhost") - val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - Utils.checkHost(driverHost, "Expected hostname") - val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) - logInfo("Connecting to " + name + ": " + url) - actorSystem.actorFor(url) - } - } - - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new spark.storage.BlockManagerMasterActor(isLocal))) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) - - val connectionManager = blockManager.connectionManager - - val broadcastManager = new BroadcastManager(isDriver) - - val cacheManager = new CacheManager(blockManager) - - // Have to assign trackerActor after initialization as MapOutputTrackerActor - // requires the MapOutputTracker itself - val mapOutputTracker = new MapOutputTracker() - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerActor(mapOutputTracker)) - - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - - val httpFileServer = new HttpFileServer() - httpFileServer.initialize() - System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) - - val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver") - } else { - MetricsSystem.createMetricsSystem("executor") - } - metricsSystem.start() - - // Set the sparkFiles directory, used when downloading dependencies. In local mode, - // this is a temporary directory; in distributed mode, this is the executor's current working - // directory. - val sparkFilesDir: String = if (isDriver) { - Utils.createTempDir().getAbsolutePath - } else { - "." - } - - // Warn about deprecated spark.cache.class property - if (System.getProperty("spark.cache.class") != null) { - logWarning("The spark.cache.class property is no longer being used! Specify storage " + - "levels using the RDD.persist() method instead.") - } - - new SparkEnv( - executorId, - actorSystem, - serializerManager, - serializer, - closureSerializer, - cacheManager, - mapOutputTracker, - shuffleFetcher, - broadcastManager, - blockManager, - connectionManager, - httpFileServer, - sparkFilesDir, - metricsSystem) - } -} diff --git a/core/src/main/scala/spark/SparkException.scala b/core/src/main/scala/spark/SparkException.scala deleted file mode 100644 index b7045eea63..0000000000 --- a/core/src/main/scala/spark/SparkException.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -class SparkException(message: String, cause: Throwable) - extends Exception(message, cause) { - - def this(message: String) = this(message, null) -} diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java deleted file mode 100644 index f9b3f7965e..0000000000 --- a/core/src/main/scala/spark/SparkFiles.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark; - -import java.io.File; - -/** - * Resolves paths to files added through `SparkContext.addFile()`. - */ -public class SparkFiles { - - private SparkFiles() {} - - /** - * Get the absolute path of a file added through `SparkContext.addFile()`. - */ - public static String get(String filename) { - return new File(getRootDirectory(), filename).getAbsolutePath(); - } - - /** - * Get the root directory that contains files added through `SparkContext.addFile()`. - */ - public static String getRootDirectory() { - return SparkEnv.get().sparkFilesDir(); - } -} diff --git a/core/src/main/scala/spark/SparkHadoopWriter.scala b/core/src/main/scala/spark/SparkHadoopWriter.scala deleted file mode 100644 index 6b330ef572..0000000000 --- a/core/src/main/scala/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.mapred - -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - -import java.text.SimpleDateFormat -import java.text.NumberFormat -import java.io.IOException -import java.util.Date - -import spark.Logging -import spark.SerializableWritable - -/** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public - * because we need to access this class from the `spark` package to use some package-private Hadoop - * functions, but this class should not be used directly by users. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. - */ -class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) - val fs: FileSystem = { - if (path != null) { - path.getFileSystem(conf.value) - } else { - FileSystem.get(conf.value) - } - } - - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter( - fs, conf.value, outputName, Reporter.NULL) - } - - def write(key: AnyRef, value: AnyRef) { - if (writer!=null) { - //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - writer.close(Reporter.NULL) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - def cleanup() { - getOutputCommitter().cleanupJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] - } - return format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - return committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - return jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - return taskContext - } - - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -object SparkHadoopWriter { - def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - return new JobID(jobtrackerID, id) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - var outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath = outputPath.makeQualified(fs) - return outputPath - } -} diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala deleted file mode 100644 index b79f4ca813..0000000000 --- a/core/src/main/scala/spark/TaskContext.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import executor.TaskMetrics -import scala.collection.mutable.ArrayBuffer - -class TaskContext( - val stageId: Int, - val splitId: Int, - val attemptId: Long, - val taskMetrics: TaskMetrics = TaskMetrics.empty() -) extends Serializable { - - @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] - - // Add a callback function to be executed on task completion. An example use - // is for HadoopRDD to register a callback to close the input stream. - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f - } - - def executeOnCompleteCallbacks() { - onCompleteCallbacks.foreach{_()} - } -} diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala deleted file mode 100644 index 3ad665da34..0000000000 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.executor.TaskMetrics -import spark.storage.BlockManagerId - -/** - * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry - * tasks several times for "ephemeral" failures, and only report back failures that require some - * old stages to be resubmitted, such as shuffle map fetch failures. - */ -private[spark] sealed trait TaskEndReason - -private[spark] case object Success extends TaskEndReason - -private[spark] -case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it - -private[spark] case class FetchFailed( - bmAddress: BlockManagerId, - shuffleId: Int, - mapId: Int, - reduceId: Int) - extends TaskEndReason - -private[spark] case class ExceptionFailure( - className: String, - description: String, - stackTrace: Array[StackTraceElement], - metrics: Option[TaskMetrics]) - extends TaskEndReason - -private[spark] case class OtherFailure(message: String) extends TaskEndReason - -private[spark] case class TaskResultTooBigFailure() extends TaskEndReason diff --git a/core/src/main/scala/spark/TaskState.scala b/core/src/main/scala/spark/TaskState.scala deleted file mode 100644 index bf75753056..0000000000 --- a/core/src/main/scala/spark/TaskState.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.apache.mesos.Protos.{TaskState => MesosTaskState} - -private[spark] object TaskState - extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { - - val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value - - val FINISHED_STATES = Set(FINISHED, FAILED, KILLED, LOST) - - type TaskState = Value - - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) - - def toMesos(state: TaskState): MesosTaskState = state match { - case LAUNCHING => MesosTaskState.TASK_STARTING - case RUNNING => MesosTaskState.TASK_RUNNING - case FINISHED => MesosTaskState.TASK_FINISHED - case FAILED => MesosTaskState.TASK_FAILED - case KILLED => MesosTaskState.TASK_KILLED - case LOST => MesosTaskState.TASK_LOST - } - - def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match { - case MesosTaskState.TASK_STAGING => LAUNCHING - case MesosTaskState.TASK_STARTING => LAUNCHING - case MesosTaskState.TASK_RUNNING => RUNNING - case MesosTaskState.TASK_FINISHED => FINISHED - case MesosTaskState.TASK_FAILED => FAILED - case MesosTaskState.TASK_KILLED => KILLED - case MesosTaskState.TASK_LOST => LOST - } -} diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala deleted file mode 100644 index bb8aad3f4c..0000000000 --- a/core/src/main/scala/spark/Utils.scala +++ /dev/null @@ -1,780 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} -import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern - -import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.collection.JavaConversions._ -import scala.io.Source - -import com.google.common.io.Files -import com.google.common.util.concurrent.ThreadFactoryBuilder - -import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} - -import spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import spark.deploy.SparkHadoopUtil -import java.nio.ByteBuffer - - -/** - * Various utility methods used by Spark. - */ -private object Utils extends Logging { - - /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(o) - oos.close() - return bos.toByteArray - } - - /** Deserialize an object using Java serialization */ - def deserialize[T](bytes: Array[Byte]): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) - return ois.readObject.asInstanceOf[T] - } - - /** Deserialize an object using Java serialization and the given ClassLoader */ - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - return ois.readObject.asInstanceOf[T] - } - - /** Serialize via nested stream using specific serializer */ - def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { - val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) - }) - try { - f(osWrapper) - } finally { - osWrapper.close() - } - } - - /** Deserialize via nested stream using specific serializer */ - def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { - val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - - override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) - }) - try { - f(isWrapper) - } finally { - isWrapper.close() - } - } - - /** - * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. - */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { - if (bb.hasArray) { - out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) - } else { - val bbval = new Array[Byte](bb.remaining()) - bb.get(bbval) - out.write(bbval) - } - } - - def isAlpha(c: Char): Boolean = { - (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') - } - - /** Split a string into words at non-alphabetic characters */ - def splitWords(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var i = 0 - while (i < s.length) { - var j = i - while (j < s.length && isAlpha(s.charAt(j))) { - j += 1 - } - if (j > i) { - buf += s.substring(i, j) - } - i = j - while (i < s.length && !isAlpha(s.charAt(i))) { - i += 1 - } - } - return buf - } - - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.find { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - }.isDefined - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - /** Create a temporary directory inside the given parent directory */ - def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { - var attempts = 0 - val maxAttempts = 10 - var dir: File = null - while (dir == null) { - attempts += 1 - if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory (under " + root + ") after " + - maxAttempts + " attempts!") - } - try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) - if (dir.exists() || !dir.mkdirs()) { - dir = null - } - } catch { case e: IOException => ; } - } - - registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) - dir - } - - /** Copy all data from an InputStream to an OutputStream */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false) - { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) - } - } - if (closeStreams) { - in.close() - out.close() - } - } - - /** - * Download a file requested by the executor. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. - * - * Throws SparkException if the target file already exists and has different contents than - * the requested file. - */ - def fetchFile(url: String, targetDir: File) { - val filename = url.split("/").last - val tempDir = getLocalDir - val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) - val targetFile = new File(targetDir, filename) - val uri = new URI(url) - uri.getScheme match { - case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - Files.move(tempFile, targetFile) - } - case "file" | null => - // In the case of a local file, copy the local file to the target directory. - // Note the difference between uri vs url. - val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) - if (targetFile.exists) { - // If the target file already exists, warn the user if - if (!Files.equal(sourceFile, targetFile)) { - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - // Do nothing if the file contents are the same, i.e. this file has been copied - // previously. - logInfo(sourceFile.getAbsolutePath + " has been previously copied to " - + targetFile.getAbsolutePath) - } - } else { - // The file does not exist in the target directory. Copy it there. - logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - Files.copy(sourceFile, targetFile) - } - case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val env = SparkEnv.get - val uri = new URI(url) - val conf = env.hadoop.newConfiguration() - val fs = FileSystem.get(uri, conf) - val in = fs.open(new Path(uri)) - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) - } else { - Files.move(tempFile, targetFile) - } - } - // Decompress the file if it's a .tar or .tar.gz - if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xzf", filename), targetDir) - } else if (filename.endsWith(".tar")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xf", filename), targetDir) - } - // Make the file executable - That's necessary for scripts - FileUtil.chmod(targetFile.getAbsolutePath, "a+x") - } - - /** - * Get a temporary directory using Spark's spark.local.dir property, if set. This will always - * return a single directory, even though the spark.local.dir property might be a list of - * multiple paths. - */ - def getLocalDir: String = { - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) - } - - /** - * Shuffle the elements of a collection into a random order, returning the - * result in a new collection. Unlike scala.util.Random.shuffle, this method - * uses a local random number generator, avoiding inter-thread contention. - */ - def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { - randomizeInPlace(seq.toArray) - } - - /** - * Shuffle the elements of an array into a random order, modifying the - * original array. Returns the original array. - */ - def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { - for (i <- (arr.length - 1) to 1 by -1) { - val j = rand.nextInt(i) - val tmp = arr(j) - arr(j) = arr(i) - arr(i) = tmp - } - arr - } - - /** - * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). - * Note, this is typically not used from within core spark. - */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) - - private def findLocalIpAddress(): String = { - val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") - if (defaultIpOverride != null) { - defaultIpOverride - } else { - val address = InetAddress.getLocalHost - if (address.isLoopbackAddress) { - // Address resolves to something like 127.0.1.1, which happens on Debian; try to find - // a better address using the local network interfaces - for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { - // We've found an address that looks reasonable! - logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress - } - } - logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + - " external IP address!") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - } - address.getHostAddress - } - } - - private var customHostname: Option[String] = None - - /** - * Allow setting a custom host name because when we run on Mesos we need to use the same - * hostname it reports to the master. - */ - def setCustomHostname(hostname: String) { - // DEBUG code - Utils.checkHost(hostname) - customHostname = Some(hostname) - } - - /** - * Get the local machine's hostname. - */ - def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) - } - - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName - } - - def localHostPort(): String = { - val retval = System.getProperty("spark.hostPort", null) - if (retval == null) { - logErrorWithStack("spark.hostPort not set but invoking localHostPort") - return localHostName() - } - - retval - } - - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) - } - - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) - } - - // Used by DEBUG code : remove when all testing done - def logErrorWithStack(msg: String) { - try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } - } - - // Typically, this will be of order of number of nodes in cluster - // If not, we should change it to LRUCache or something. - private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() - - def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - var cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached - } - - val indx: Int = hostPort.lastIndexOf(':') - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... - // but then hadoop does not support ipv6 right now. - // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 - if (-1 == indx) { - val retval = (hostPort, 0) - hostPortParseResults.put(hostPort, retval) - return retval - } - - val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) - hostPortParseResults.putIfAbsent(hostPort, retval) - hostPortParseResults.get(hostPort) - } - - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() - - /** - * Wrapper over newCachedThreadPool. - */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = - Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Return the string to tell how long has passed in seconds. The passing parameter should be in - * millisecond. - */ - def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms" - } - - /** - * Wrapper over newFixedThreadPool. - */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = - Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(file: File) { - if (file.isDirectory) { - for (child <- file.listFiles()) { - deleteRecursively(child) - } - } - if (!file.delete()) { - throw new IOException("Failed to delete: " + file) - } - } - - /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM - * environment variable. - */ - def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } - } - - /** - * Convert a quantity in bytes to a human-readable string such as "4.0 MB". - */ - def bytesToString(size: Long): String = { - val TB = 1L << 40 - val GB = 1L << 30 - val MB = 1L << 20 - val KB = 1L << 10 - - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") - } - } - "%.1f %s".formatLocal(Locale.US, value, unit) - } - - /** - * Returns a human-readable string representing a duration such as "35ms" - */ - def msDurationToString(ms: Long): String = { - val second = 1000 - val minute = 60 * second - val hour = 60 * minute - - ms match { - case t if t < second => - "%d ms".format(t) - case t if t < minute => - "%.1f s".format(t.toFloat / second) - case t if t < hour => - "%.1f m".format(t.toFloat / minute) - case t => - "%.2f h".format(t.toFloat / hour) - } - } - - /** - * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". - */ - def megabytesToString(megabytes: Long): String = { - bytesToString(megabytes * 1024L * 1024L) - } - - /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - } - - /** - * Execute a command in the current working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String]) { - execute(command, new File(".")) - } - - /** - * Execute a command and get its output, throwing an exception if it yields a code other than 0. - */ - def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - output.append(line) - } - } - } - stdoutThread.start() - val exitCode = process.waitFor() - stdoutThread.join() // Wait for it to finish reading output - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - output.toString - } - - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r - - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) - - /** - * When called inside a class in the spark package, returns the name of the user code class - * (outside the spark package) that called into Spark, as well as which Spark method they called. - * This is used, for example, to tell users where in their code each RDD got created. - */ - def getCallSiteInfo: CallSiteInfo = { - val trace = Thread.currentThread.getStackTrace().filter( el => - (!el.getMethodName.contains("getStackTrace"))) - - // Keep crawling up the stack trace until we find the first function not inside of the spark - // package. We track the last (shallowest) contiguous Spark method. This might be an RDD - // transformation, a SparkContext function (such as parallelize), or anything else that leads - // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. - var lastSparkMethod = "" - var firstUserFile = "" - var firstUserLine = 0 - var finished = false - var firstUserClass = "" - - for (el <- trace) { - if (!finished) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) { - lastSparkMethod = if (el.getMethodName == "") { - // Spark method is a constructor; get its class name - el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) - } else { - el.getMethodName - } - } - else { - firstUserLine = el.getLineNumber - firstUserFile = el.getFileName - firstUserClass = el.getClassName - finished = true - } - } - } - new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) - } - - def formatSparkCallSite = { - val callSiteInfo = getCallSiteInfo - "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, - callSiteInfo.firstUserLine) - } - - /** Return a string containing part of a file from byte 'start' to 'end'. */ - def offsetBytes(path: String, start: Long, end: Long): String = { - val file = new File(path) - val length = file.length() - val effectiveEnd = math.min(length, end) - val effectiveStart = math.max(0, start) - val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) - val stream = new FileInputStream(file) - - stream.skip(effectiveStart) - stream.read(buff) - stream.close() - Source.fromBytes(buff).mkString - } - - /** - * Clone an object using a Spark serializer. - */ - def clone[T](value: T, serializer: SerializerInstance): T = { - serializer.deserialize[T](serializer.serialize(value)) - } - - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - return false - } - - def isSpace(c: Char): Boolean = { - " \t\r\n".indexOf(c) != -1 - } - - /** - * Split a string of potentially quoted arguments from the command line the way that a shell - * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', - * then it would be parsed as three arguments: 'a', 'b c' and 'd'. - */ - def splitCommandString(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var inWord = false - var inSingleQuote = false - var inDoubleQuote = false - var curWord = new StringBuilder - def endWord() { - buf += curWord.toString - curWord.clear() - } - var i = 0 - while (i < s.length) { - var nextChar = s.charAt(i) - if (inDoubleQuote) { - if (nextChar == '"') { - inDoubleQuote = false - } else if (nextChar == '\\') { - if (i < s.length - 1) { - // Append the next character directly, because only " and \ may be escaped in - // double quotes after the shell's own expansion - curWord.append(s.charAt(i + 1)) - i += 1 - } - } else { - curWord.append(nextChar) - } - } else if (inSingleQuote) { - if (nextChar == '\'') { - inSingleQuote = false - } else { - curWord.append(nextChar) - } - // Backslashes are not treated specially in single quotes - } else if (nextChar == '"') { - inWord = true - inDoubleQuote = true - } else if (nextChar == '\'') { - inWord = true - inSingleQuote = true - } else if (!isSpace(nextChar)) { - curWord.append(nextChar) - inWord = true - } else if (inWord && isSpace(nextChar)) { - endWord() - inWord = false - } - i += 1 - } - if (inWord || inDoubleQuote || inSingleQuote) { - endWord() - } - return buf - } - - /* Calculates 'x' modulo 'mod', takes to consideration sign of x, - * i.e. if 'x' is negative, than 'x' % 'mod' is negative too - * so function return (x % mod) + mod in that case. - */ - def nonNegativeMod(x: Int, mod: Int): Int = { - val rawMod = x % mod - rawMod + (if (rawMod < 0) mod else 0) - } -} diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala deleted file mode 100644 index 8ce7df6213..0000000000 --- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import spark.RDD -import spark.SparkContext.doubleRDDToDoubleRDDFunctions -import spark.api.java.function.{Function => JFunction} -import spark.util.StatCounter -import spark.partial.{BoundedDouble, PartialResult} -import spark.storage.StorageLevel -import java.lang.Double -import spark.Partitioner - -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { - - override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] - - override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x)) - - override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD = - new JavaDoubleRDD(rdd.map(_.doubleValue)) - - // Common RDD functions - - import JavaDoubleRDD.fromRDD - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. - */ - def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) - - // first() has to be overriden here in order for its return type to be Double instead of Object. - override def first(): Double = srdd.first() - - // Transformations (return a new RDD) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct()) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = - fromRDD(srdd.filter(x => f(x).booleanValue())) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions)) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD = - fromRDD(srdd.coalesce(numPartitions, shuffle)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: JavaDoubleRDD): JavaDoubleRDD = - fromRDD(srdd.subtract(other)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD = - fromRDD(srdd.subtract(other, numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD = - fromRDD(srdd.subtract(other, p)) - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD = - fromRDD(srdd.sample(withReplacement, fraction, seed)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) - - // Double RDD functions - - /** Add up the elements in this RDD. */ - def sum(): Double = srdd.sum() - - /** - * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count - * of the RDD's elements in one operation. - */ - def stats(): StatCounter = srdd.stats() - - /** Compute the mean of this RDD's elements. */ - def mean(): Double = srdd.mean() - - /** Compute the variance of this RDD's elements. */ - def variance(): Double = srdd.variance() - - /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = srdd.stdev() - - /** - * Compute the sample standard deviation of this RDD's elements (which corrects for bias in - * estimating the standard deviation by dividing by N-1 instead of N). - */ - def sampleStdev(): Double = srdd.sampleStdev() - - /** - * Compute the sample variance of this RDD's elements (which corrects for bias in - * estimating the standard variance by dividing by N-1 instead of N). - */ - def sampleVariance(): Double = srdd.sampleVariance() - - /** Return the approximate mean of the elements in this RDD. */ - def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = - srdd.meanApprox(timeout, confidence) - - /** (Experimental) Approximate operation to return the mean within a timeout. */ - def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = - srdd.sumApprox(timeout, confidence) - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) -} - -object JavaDoubleRDD { - def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd) - - implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd -} diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala deleted file mode 100644 index effe6e5e0d..0000000000 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ /dev/null @@ -1,601 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import java.util.{List => JList} -import java.util.Comparator - -import scala.Tuple2 -import scala.collection.JavaConversions._ - -import com.google.common.base.Optional -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.conf.Configuration - -import spark.HashPartitioner -import spark.Partitioner -import spark.Partitioner._ -import spark.RDD -import spark.SparkContext.rddToPairRDDFunctions -import spark.api.java.function.{Function2 => JFunction2} -import spark.api.java.function.{Function => JFunction} -import spark.partial.BoundedDouble -import spark.partial.PartialResult -import spark.rdd.OrderedRDDFunctions -import spark.storage.StorageLevel - - -class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K], - implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { - - override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) - - override val classManifest: ClassManifest[(K, V)] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] - - import JavaPairRDD._ - - // Common RDD functions - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. - */ - def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.persist(newLevel)) - - // Transformations (return a new RDD) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct()) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions)) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] = - fromRDD(rdd.coalesce(numPartitions, shuffle)) - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.union(other.rdd)) - - // first() has to be overridden here so that the generated method has the signature - // 'public scala.Tuple2 first()'; if the trait's definition is used, - // then the method has the signature 'public java.lang.Object first()', - // causing NoSuchMethodErrors at runtime. - override def first(): (K, V) = rdd.first() - - // Pair RDD functions - - /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: - * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. - * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - fromRDD(rdd.combineByKey( - createCombiner, - mergeValue, - mergeCombiners, - partitioner - )) - } - - /** - * Simplified version of combineByKey that hash-partitions the output RDD. - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - numPartitions: Int): JavaPairRDD[K, C] = - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.reduceByKey(partitioner, func)) - - /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. - */ - def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) - - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue, partitioner)(func)) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func)) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue)(func)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] = - fromRDD(rdd.reduceByKey(func, numPartitions)) - - /** - * Group the values for each key in the RDD into a single sequence. Allows controlling the - * partitioning of the resulting key-value pair RDD by passing a Partitioner. - */ - def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other, numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other, p)) - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] = - fromRDD(rdd.partitionBy(partitioner)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other, partitioner)) - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other, partitioner) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other, partitioner) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing - * partitioner/parallelism level. - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { - fromRDD(reduceByKey(defaultPartitioner(rdd), func)) - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. - */ - def groupByKey(): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey())) - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other)) - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other, numPartitions)) - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * using the existing partitioner/parallelism level. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numPartitions` partitions. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other, numPartitions) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the existing partitioner/parallelism level. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD into the given number of partitions. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other, numPartitions) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Return the key-value pairs in this RDD to the master as a Map. - */ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) - - /** - * Pass each value in the key-value pair RDD through a map function without changing the keys; - * this also retains the original RDD's partitioning. - */ - def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - fromRDD(rdd.mapValues(f)) - } - - /** - * Pass each value in the key-value pair RDD through a flatMap function without changing the - * keys; this also retains the original RDD's partitioning. - */ - def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { - import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - fromRDD(rdd.flatMapValues(fn)) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.cogroup(other))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])] - = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) - - /** Alias for cogroup. */ - def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.groupWith(other))) - - /** Alias for cogroup. */ - def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) - - /** - * Return the list of values in the RDD for key `key`. This operation is done efficiently if the - * RDD has a known partitioner by only searching the partition that the key maps to. - */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: JobConf) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) - } - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - - /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - codec: Class[_ <: CompressionCodec]) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) - } - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: Configuration) { - rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) - } - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { - rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - - /** - * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for - * that storage system. The JobConf should set an OutputFormat and any output paths required - * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop - * MapReduce job. - */ - def saveAsHadoopDataset(conf: JobConf) { - rdd.saveAsHadoopDataset(conf) - } - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements in - * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an - * ordered list of records (in the `save` case, they will be written to multiple `part-X` files - * in the filesystem, in order of the keys). - */ - def sortByKey(): JavaPairRDD[K, V] = sortByKey(true) - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = { - val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] - sortByKey(comp, ascending) - } - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true) - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = { - class KeyOrdering(val a: K) extends Ordered[K] { - override def compare(b: K) = comp.compare(a, b) - } - implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) - fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending)) - } - - /** - * Return an RDD with the keys of each tuple. - */ - def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) - - /** - * Return an RDD with the values of each tuple. - */ - def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) -} - -object JavaPairRDD { - def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[T]): RDD[(K, JList[T])] = - rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _) - - def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V], - Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) - - def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1], - Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1], - JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues( - (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1), - seqAsJavaList(x._2), - seqAsJavaList(x._3))) - - def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd) - - implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd -} diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala deleted file mode 100644 index c0bf2cf568..0000000000 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import spark._ -import spark.api.java.function.{Function => JFunction} -import spark.storage.StorageLevel - -class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends -JavaRDDLike[T, JavaRDD[T]] { - - override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) - - // Common RDD functions - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. - */ - def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel)) - - /** - * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. - */ - def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) - - // Transformations (return a new RDD) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct()) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = - wrapRDD(rdd.filter((x => f(x).booleanValue()))) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] = - rdd.coalesce(numPartitions, shuffle) - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = - wrapRDD(rdd.sample(withReplacement, fraction, seed)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] = - wrapRDD(rdd.subtract(other, numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = - wrapRDD(rdd.subtract(other, p)) -} - -object JavaRDD { - - implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) - - implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd -} - diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala deleted file mode 100644 index 2c2b138f16..0000000000 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ /dev/null @@ -1,426 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import java.util.{List => JList, Comparator} -import scala.Tuple2 -import scala.collection.JavaConversions._ - -import org.apache.hadoop.io.compress.CompressionCodec -import spark.{SparkContext, Partition, RDD, TaskContext} -import spark.api.java.JavaPairRDD._ -import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} -import spark.partial.{PartialResult, BoundedDouble} -import spark.storage.StorageLevel -import com.google.common.base.Optional - - -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { - def wrapRDD(rdd: RDD[T]): This - - implicit val classManifest: ClassManifest[T] - - def rdd: RDD[T] - - /** Set of partitions in this RDD. */ - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context: SparkContext = rdd.context - - /** A unique ID for this RDD (within its SparkContext). */ - def id: Int = rdd.id - - /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel: StorageLevel = rdd.getStorageLevel - - /** - * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. - * This should ''not'' be called by users directly, but is available for implementors of custom - * subclasses of RDD. - */ - def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) - - // Transformations (return a new RDD) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[R](f: JFunction[T, R]): JavaRDD[R] = - new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType()) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[R](f: DoubleFunction[T]): JavaDoubleRDD = - new JavaDoubleRDD(rdd.map(x => f(x).doubleValue())) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): - JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) - } - - /** - * Return an RDD created by coalescing all elements within each partition into an array. - */ - def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - - /** - * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of - * elements (a, b) where a is in `this` and b is in `other`. - */ - def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = - JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest, - other.classManifest) - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) - } - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) - } - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: String): JavaRDD[String] = rdd.pipe(command) - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) - - /** - * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, - * second element in each RDD, etc. Assumes that the two RDDs have the *same number of - * partitions* and the *same number of elements in each partition* (e.g. one was made through - * a map on the other). - */ - def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = { - JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest) - } - - /** - * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by - * applying a function to the zipped partitions. Assumes that all the RDDs have the - * *same number of partitions*, but does *not* require them to have the same number - * of elements in each partition. - */ - def zipPartitions[U, V]( - other: JavaRDDLike[U, _], - f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) - JavaRDD.fromRDD( - rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType()) - } - - // Actions (launch a job to return a value to the user program) - - /** - * Applies a function f to all elements of this RDD. - */ - def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean(f) - rdd.foreach(cleanF) - } - - /** - * Return an array that contains all of the elements in this RDD. - */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } - - /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. - */ - def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. - */ - def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = - rdd.fold(zeroValue)(f) - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using - * given combine functions and a neutral "zero value". This function can return a different result - * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U - * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are - * allowed to modify and return their first argument instead of creating a new U to avoid memory - * allocation. - */ - def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U], - combOp: JFunction2[U, U, U]): U = - rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType) - - /** - * Return the number of elements in the RDD. - */ - def count(): Long = rdd.count() - - /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result - * within a timeout, even if not all tasks have finished. - */ - def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = - rdd.countApprox(timeout, confidence) - - /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result - * within a timeout, even if not all tasks have finished. - */ - def countApprox(timeout: Long): PartialResult[BoundedDouble] = - rdd.countApprox(timeout) - - /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. - */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) - - /** - * (Experimental) Approximate version of countByValue(). - */ - def countByValueApprox( - timeout: Long, - confidence: Double - ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) - - /** - * (Experimental) Approximate version of countByValue(). - */ - def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) - - /** - * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so - * it will be slow if a lot of partitions are required. In that case, use collect() to get the - * whole RDD instead. - */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } - - def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } - - /** - * Return the first element in this RDD. - */ - def first(): T = rdd.first() - - /** - * Save this RDD as a text file, using string representations of elements. - */ - def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) - - - /** - * Save this RDD as a compressed text file, using string representations of elements. - */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = - rdd.saveAsTextFile(path, codec) - - /** - * Save this RDD as a SequenceFile of serialized objects. - */ - def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) - - /** - * Creates tuples of the elements in this RDD by applying `f`. - */ - def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { - implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - JavaPairRDD.fromRDD(rdd.keyBy(f)) - } - - /** - * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent - * RDDs will be removed. This function must be called before any job has been - * executed on this RDD. It is strongly recommended that this RDD is persisted in - * memory, otherwise saving it on a file will require recomputation. - */ - def checkpoint() = rdd.checkpoint() - - /** - * Return whether this RDD has been checkpointed or not - */ - def isCheckpointed: Boolean = rdd.isCheckpointed - - /** - * Gets the name of the file to which this RDD was checkpointed - */ - def getCheckpointFile(): Optional[String] = { - JavaUtils.optionToOptional(rdd.getCheckpointFile) - } - - /** A description of this RDD and its recursive dependencies for debugging. */ - def toDebugString(): String = { - rdd.toDebugString - } - - /** - * Returns the top K elements from this RDD as defined by - * the specified Comparator[T]. - * @param num the number of top elements to return - * @param comp the comparator that defines the order - * @return an array of top elements - */ - def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) - } - - /** - * Returns the top K elements from this RDD using the - * natural ordering for T. - * @param num the number of top elements to return - * @return an array of top elements - */ - def top(num: Int): JList[T] = { - val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] - top(num, comp) - } - - /** - * Returns the first K elements from this RDD as defined by - * the specified Comparator[T] and maintains the order. - * @param num the number of top elements to return - * @param comp the comparator that defines the order - * @return an array of top elements - */ - def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) - } - - /** - * Returns the first K elements from this RDD using the - * natural ordering for T while maintain the order. - * @param num the number of top elements to return - * @return an array of top elements - */ - def takeOrdered(num: Int): JList[T] = { - val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] - takeOrdered(num, comp) - } -} diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala deleted file mode 100644 index 29d57004b5..0000000000 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import java.util.{Map => JMap} - -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - -import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} -import spark.SparkContext.IntAccumulatorParam -import spark.SparkContext.DoubleAccumulatorParam -import spark.broadcast.Broadcast - -import com.google.common.base.Optional - -/** - * A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and - * works with Java collections instead of Scala ones. - */ -class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI - */ - def this(master: String, appName: String) = this(new SparkContext(master, appName)) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jarFile JAR file to send to the cluster. This can be a path on the local file system - * or an HDFS, HTTP, HTTPS, or FTP URL. - */ - def this(master: String, appName: String, sparkHome: String, jarFile: String) = - this(new SparkContext(master, appName, sparkHome, Seq(jarFile))) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, appName: String, sparkHome: String, jars: Array[String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq)) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - * @param environment Environment variables to set on worker nodes - */ - def this(master: String, appName: String, sparkHome: String, jars: Array[String], - environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment)) - - private[spark] val env = sc.env - - /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) - } - - /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T](list: java.util.List[T]): JavaRDD[T] = - parallelize(list, sc.defaultParallelism) - - /** Distribute a local Scala collection to form an RDD. */ - def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int) - : JavaPairRDD[K, V] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) - } - - /** Distribute a local Scala collection to form an RDD. */ - def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] = - parallelizePairs(list, sc.defaultParallelism) - - /** Distribute a local Scala collection to form an RDD. */ - def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) - - /** Distribute a local Scala collection to form an RDD. */ - def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = - parallelizeDoubles(list, sc.defaultParallelism) - - /** - * Read a text file from HDFS, a local file system (available on all nodes), or any - * Hadoop-supported file system URI, and return it as an RDD of Strings. - */ - def textFile(path: String): JavaRDD[String] = sc.textFile(path) - - /** - * Read a text file from HDFS, a local file system (available on all nodes), or any - * Hadoop-supported file system URI, and return it as an RDD of Strings. - */ - def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits) - - /**Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits)) - } - - /**Get an RDD for a Hadoop SequenceFile. */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): - JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass)) - } - - /** - * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and - * BytesWritable values that contain a serialized partition. This is still an experimental storage - * format and may not be supported exactly as is in future Spark releases. It will also be pretty - * slow if you use the default serializer (Java serialization), though the nice thing about it is - * that there's very little effort required to save arbitrary objects. - */ - def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - sc.objectFile(path, minSplits)(cm) - } - - /** - * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and - * BytesWritable values that contain a serialized partition. This is still an experimental storage - * format and may not be supported exactly as is in future Spark releases. It will also be pretty - * slow if you use the default serializer (Java serialization), though the nice thing about it is - * that there's very little effort required to save arbitrary objects. - */ - def objectFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - sc.objectFile(path)(cm) - } - - /** - * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). - */ - def hadoopRDD[K, V, F <: InputFormat[K, V]]( - conf: JobConf, - inputFormatClass: Class[F], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits)) - } - - /** - * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). - */ - def hadoopRDD[K, V, F <: InputFormat[K, V]]( - conf: JobConf, - inputFormatClass: Class[F], - keyClass: Class[K], - valueClass: Class[V] - ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) - } - - /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V, F <: InputFormat[K, V]]( - path: String, - inputFormatClass: Class[F], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)) - } - - /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V, F <: InputFormat[K, V]]( - path: String, - inputFormatClass: Class[F], - keyClass: Class[K], - valueClass: Class[V] - ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) - new JavaPairRDD(sc.hadoopFile(path, - inputFormatClass, keyClass, valueClass)) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( - path: String, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V], - conf: Configuration): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) - new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V]): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) - new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) - } - - /** Build the union of two or more RDDs. */ - override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[T] = first.classManifest - sc.union(rdds)(cm) - } - - /** Build the union of two or more RDDs. */ - override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) - : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[(K, V)] = first.classManifest - implicit val kcm: ClassManifest[K] = first.kManifest - implicit val vcm: ClassManifest[V] = first.vManifest - new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm) - } - - /** Build the union of two or more RDDs. */ - override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) - new JavaDoubleRDD(sc.union(rdds)) - } - - /** - * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = - sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] - - /** - * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] - - /** - * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) - - /** - * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - def accumulator(initialValue: Double): Accumulator[java.lang.Double] = - doubleAccumulator(initialValue) - - /** - * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = - sc.accumulator(initialValue)(accumulatorParam) - - /** - * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can - * "add" values with `add`. Only the master can access the accumuable's `value`. - */ - def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = - sc.accumulable(initialValue)(param) - - /** - * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for - * reading it in distributed functions. The variable will be sent to each cluster only once. - */ - def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) - - /** Shut down the SparkContext. */ - def stop() { - sc.stop() - } - - /** - * Get Spark's home location from either a value set through the constructor, - * or the spark.home Java property, or the SPARK_HOME environment variable - * (in that order of preference). If neither of these is set, return None. - */ - def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome()) - - /** - * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(path)` to find its download location. - */ - def addFile(path: String) { - sc.addFile(path) - } - - /** - * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. - */ - def addJar(path: String) { - sc.addJar(path) - } - - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - def clearJars() { - sc.clearJars() - } - - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - def clearFiles() { - sc.clearFiles() - } - - /** - * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. - */ - def hadoopConfiguration(): Configuration = { - sc.hadoopConfiguration - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. - */ - def setCheckpointDir(dir: String, useExisting: Boolean) { - sc.setCheckpointDir(dir, useExisting) - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists, an exception will be thrown to prevent accidental - * overriding of checkpoint files. - */ - def setCheckpointDir(dir: String) { - sc.setCheckpointDir(dir) - } - - protected def checkpointFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - new JavaRDD(sc.checkpointFile(path)) - } -} - -object JavaSparkContext { - implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc) - - implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc -} diff --git a/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java deleted file mode 100644 index 42b1de01b1..0000000000 --- a/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java; - -import java.util.Arrays; -import java.util.ArrayList; -import java.util.List; - -// See -// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html -abstract class JavaSparkContextVarargsWorkaround { - public JavaRDD union(JavaRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - ArrayList> rest = new ArrayList>(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - public JavaDoubleRDD union(JavaDoubleRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - ArrayList rest = new ArrayList(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - public JavaPairRDD union(JavaPairRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - ArrayList> rest = new ArrayList>(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - // These methods take separate "first" and "rest" elements to avoid having the same type erasure - abstract public JavaRDD union(JavaRDD first, List> rest); - abstract public JavaDoubleRDD union(JavaDoubleRDD first, List rest); - abstract public JavaPairRDD union(JavaPairRDD first, List> rest); -} diff --git a/core/src/main/scala/spark/api/java/JavaUtils.scala b/core/src/main/scala/spark/api/java/JavaUtils.scala deleted file mode 100644 index ffc131ac83..0000000000 --- a/core/src/main/scala/spark/api/java/JavaUtils.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import com.google.common.base.Optional - -object JavaUtils { - def optionToOptional[T](option: Option[T]): Optional[T] = - option match { - case Some(value) => Optional.of(value) - case None => Optional.absent() - } -} diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java deleted file mode 100644 index f385636e83..0000000000 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java; - -import spark.storage.StorageLevel; - -/** - * Expose some commonly useful storage level constants. - */ -public class StorageLevels { - public static final StorageLevel NONE = new StorageLevel(false, false, false, 1); - public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1); - public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2); - public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1); - public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2); - public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1); - public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2); - public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1); - public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); - public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); - public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); - - /** - * Create a new StorageLevel object. - * @param useDisk saved to disk, if true - * @param useMemory saved to memory, if true - * @param deserialized saved as deserialized objects, if true - * @param replication replication factor - */ - public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { - return StorageLevel.apply(useDisk, useMemory, deserialized, replication); - } -} diff --git a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java deleted file mode 100644 index 8bc88d757f..0000000000 --- a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - - -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - -/** - * A function that returns zero or more records of type Double from each input record. - */ -// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is -// overloaded for both FlatMapFunction and DoubleFlatMapFunction. -public abstract class DoubleFlatMapFunction extends AbstractFunction1> - implements Serializable { - - public abstract Iterable call(T t); - - @Override - public final Iterable apply(T t) { return call(t); } -} diff --git a/core/src/main/scala/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFunction.java deleted file mode 100644 index 1aa1e5dae0..0000000000 --- a/core/src/main/scala/spark/api/java/function/DoubleFunction.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - - -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - -/** - * A function that returns Doubles, and can be used to construct DoubleRDDs. - */ -// DoubleFunction does not extend Function because some UDF functions, like map, -// are overloaded for both Function and DoubleFunction. -public abstract class DoubleFunction extends WrappedFunction1 - implements Serializable { - - public abstract Double call(T t) throws Exception; -} diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala deleted file mode 100644 index 9eb0cfe3f9..0000000000 --- a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -/** - * A function that returns zero or more output records from each input record. - */ -abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - @throws(classOf[Exception]) - def call(x: T) : java.lang.Iterable[R] - - def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]] -} diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala deleted file mode 100644 index dda98710c2..0000000000 --- a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -/** - * A function that takes two inputs and returns zero or more output records. - */ -abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - @throws(classOf[Exception]) - def call(a: A, b:B) : java.lang.Iterable[C] - - def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] -} diff --git a/core/src/main/scala/spark/api/java/function/Function.java b/core/src/main/scala/spark/api/java/function/Function.java deleted file mode 100644 index 2a2ea0aacf..0000000000 --- a/core/src/main/scala/spark/api/java/function/Function.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - - -/** - * Base class for functions whose return types do not create special RDDs. PairFunction and - * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed - * when mapping RDDs of other types. - */ -public abstract class Function extends WrappedFunction1 implements Serializable { - public abstract R call(T t) throws Exception; - - public ClassManifest returnType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } -} - diff --git a/core/src/main/scala/spark/api/java/function/Function2.java b/core/src/main/scala/spark/api/java/function/Function2.java deleted file mode 100644 index 952d31ece4..0000000000 --- a/core/src/main/scala/spark/api/java/function/Function2.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction2; - -import java.io.Serializable; - -/** - * A two-argument function that takes arguments of type T1 and T2 and returns an R. - */ -public abstract class Function2 extends WrappedFunction2 - implements Serializable { - - public abstract R call(T1 t1, T2 t2) throws Exception; - - public ClassManifest returnType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } -} - diff --git a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java deleted file mode 100644 index 4aad602da3..0000000000 --- a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - -import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - -/** - * A function that returns zero or more key-value pair records from each input record. The - * key-value pairs are represented as scala.Tuple2 objects. - */ -// PairFlatMapFunction does not extend FlatMapFunction because flatMap is -// overloaded for both FlatMapFunction and PairFlatMapFunction. -public abstract class PairFlatMapFunction - extends WrappedFunction1>> - implements Serializable { - - public abstract Iterable> call(T t) throws Exception; - - public ClassManifest keyType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } - - public ClassManifest valueType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } -} diff --git a/core/src/main/scala/spark/api/java/function/PairFunction.java b/core/src/main/scala/spark/api/java/function/PairFunction.java deleted file mode 100644 index ccfe64ecf1..0000000000 --- a/core/src/main/scala/spark/api/java/function/PairFunction.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - -import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - -/** - * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs. - */ -// PairFunction does not extend Function because some UDF functions, like map, -// are overloaded for both Function and PairFunction. -public abstract class PairFunction - extends WrappedFunction1> - implements Serializable { - - public abstract Tuple2 call(T t) throws Exception; - - public ClassManifest keyType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } - - public ClassManifest valueType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); - } -} diff --git a/core/src/main/scala/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/spark/api/java/function/VoidFunction.scala deleted file mode 100644 index f6fc0b0f7d..0000000000 --- a/core/src/main/scala/spark/api/java/function/VoidFunction.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -/** - * A function with no return value. - */ -// This allows Java users to write void methods without having to return Unit. -abstract class VoidFunction[T] extends Serializable { - @throws(classOf[Exception]) - def call(t: T) : Unit -} - -// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly -// return Unit), so it is implicitly converted to a Function1[T, Unit]: -object VoidFunction { - implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x)) -} diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala deleted file mode 100644 index 1758a38c4e..0000000000 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -import scala.runtime.AbstractFunction1 - -/** - * Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply - * isn't marked to allow that). - */ -private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { - @throws(classOf[Exception]) - def call(t: T): R - - final def apply(t: T): R = call(t) -} diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala deleted file mode 100644 index b093567d2c..0000000000 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -import scala.runtime.AbstractFunction2 - -/** - * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply - * isn't marked to allow that). - */ -private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { - @throws(classOf[Exception]) - def call(t1: T1, t2: T2): R - - final def apply(t1: T1, t2: T2): R = call(t1, t2) -} diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala deleted file mode 100644 index ac112b8c2c..0000000000 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import spark.Partitioner -import spark.Utils -import java.util.Arrays - -/** - * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. - * - * Stores the unique id() of the Python-side partitioning function so that it is incorporated into - * equality comparisons. Correctness requires that the id is a unique identifier for the - * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning - * function). This can be ensured by using the Python id() function and maintaining a reference - * to the Python partitioning function so that its id() is not reused. - */ -private[spark] class PythonPartitioner( - override val numPartitions: Int, - val pyPartitionFunctionId: Long) - extends Partitioner { - - override def getPartition(key: Any): Int = key match { - case null => 0 - case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions) - case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions) - } - - override def equals(other: Any): Boolean = other match { - case h: PythonPartitioner => - h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId - case _ => - false - } -} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala deleted file mode 100644 index 49671437d0..0000000000 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import java.io._ -import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} - -import scala.collection.JavaConversions._ - -import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import spark.broadcast.Broadcast -import spark._ -import spark.rdd.PipedRDD - - -private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) - extends RDD[Array[Byte]](parent) { - - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - - override def getPartitions = parent.partitions - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val worker = env.createPythonWorker(pythonExec, envVars.toMap) - - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + pythonExec) { - override def run() { - try { - SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) - // Partition index - dataOut.writeInt(split.index) - // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) - // Broadcast variables - dataOut.writeInt(broadcastVars.length) - for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) - } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } - dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() - // Data values - for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) - } - dataOut.flush() - printOut.flush() - worker.shutdownOutput() - } catch { - case e: IOException => - // This can happen for legitimate reasons if the Python code stops returning data before we are done - // passing elements through, e.g., for take(). Just log a message to say it happened. - logInfo("stdin writer to Python finished early") - logDebug("stdin writer to Python finished early", e) - } - } - }.start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - return new Iterator[Array[Byte]] { - def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - // FIXME: can deadlock if worker is waiting for us to - // respond to current message (currently irrelevant because - // output is shutdown before we read any input) - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case -3 => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) - read - case -2 => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj)) - case -1 => - // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) - stream.readFully(update) - accumulator += Collections.singletonList(update) - len2 = stream.readInt() - } - new Array[Byte](0) - } - } catch { - case eof: EOFException => { - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - case e => throw e - } - } - - var _nextObj = read() - - def hasNext = _nextObj.length != 0 - } - } - - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) -} - -/** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String) extends Exception(msg) - -/** - * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. - * This is used by PySpark's shuffle operations. - */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Array[Byte], Array[Byte])](prev) { - override def getPartitions = prev.partitions - override def compute(split: Partition, context: TaskContext) = - prev.iterator(split, context).grouped(2).map { - case Seq(a, b) => (a, b) - case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) - } - val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) -} - -private[spark] object PythonRDD { - - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } - - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } - - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : - JavaRDD[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] - try { - while (true) { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - objs.append(obj) - } - } catch { - case eof: EOFException => {} - case e => throw e - } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) - } - - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) - } - - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - for (item <- items) { - writeAsPickle(item, file) - } - file.close() - } - - def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { - implicit val cm : ClassManifest[T] = rdd.elementClassManifest - rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator - } -} - -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - -private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") -} - -/** - * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it - * collects a list of pickled strings that we pass to Python through a socket. - */ -class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) - extends AccumulatorParam[JList[Array[Byte]]] { - - Utils.checkHost(serverHost, "Expected hostname") - - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList - - override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { - if (serverHost == null) { - // This happens on the worker node, where we just want to remember all the updates - val1.addAll(val2) - val1 - } else { - // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - val in = socket.getInputStream - val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) - out.writeInt(val2.size) - for (array <- val2) { - out.writeInt(array.length) - out.write(array) - } - out.flush() - // Wait for a byte from the Python side as an acknowledgement - val byteRead = in.read() - if (byteRead == -1) { - throw new SparkException("EOF reached before Python server acknowledged") - } - socket.close() - null - } - } -} diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala deleted file mode 100644 index 14f8320678..0000000000 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import java.io.{File, DataInputStream, IOException} -import java.net.{Socket, SocketException, InetAddress} - -import scala.collection.JavaConversions._ - -import spark._ - -private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) - extends Logging { - var daemon: Process = null - val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) - var daemonPort: Int = 0 - - def create(): Socket = { - synchronized { - // Start the daemon if it hasn't been started - startDaemon() - - // Attempt to connect, restart and retry once if it fails - try { - new Socket(daemonHost, daemonPort) - } catch { - case exc: SocketException => { - logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon() - startDaemon() - new Socket(daemonHost, daemonPort) - } - case e => throw e - } - } - } - - def stop() { - stopDaemon() - } - - private def startDaemon() { - synchronized { - // Is it already running? - if (daemon != null) { - return - } - - try { - // Create and start the daemon - val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) - val workerEnv = pb.environment() - workerEnv.putAll(envVars) - val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") - workerEnv.put("PYTHONPATH", pythonPath) - daemon = pb.start() - - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - val in = daemon.getErrorStream - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - - val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() - - // Redirect further stdout output to our stderr - new Thread("stdout reader for " + pythonExec) { - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - } catch { - case e => { - stopDaemon() - throw e - } - } - - // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly - // detect our disappearance. - } - } - - private def stopDaemon() { - synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() - } - - daemon = null - daemonPort = 0 - } - } -} diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala deleted file mode 100644 index 6f7d385379..0000000000 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ /dev/null @@ -1,1057 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import spark._ -import spark.storage.StorageLevel - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId: String = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. - // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. - private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. - if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala deleted file mode 100644 index aba56a60ca..0000000000 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io._ -import java.util.concurrent.atomic.AtomicLong - -import spark._ - -abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { - def value: T - - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. - - override def toString = "spark.Broadcast(" + id + ")" -} - -private[spark] -class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = System.getProperty( - "spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver) - - initialized = true - } - } - } - - def stop() { - broadcastFactory.stop() - } - - private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - - def isDriver = _isDriver -} diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala deleted file mode 100644 index d33d95c7d9..0000000000 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -/** - * An interface for all the broadcast implementations in Spark (to allow - * multiple broadcast implementations). SparkContext uses a user-specified - * BroadcastFactory implementation to instantiate a particular broadcast for the - * entire Spark job. - */ -private[spark] trait BroadcastFactory { - def initialize(isDriver: Boolean): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] - def stop(): Unit -} diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala deleted file mode 100644 index 138a8c21bc..0000000000 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.URL - -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import spark.{HttpServer, Logging, SparkEnv, Utils} -import spark.io.CompressionCodec -import spark.storage.StorageLevel -import spark.util.{MetadataCleaner, TimeStampedHashSet} - - -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId: String = "broadcast_" + id - - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - if (!isLocal) { - HttpBroadcast.write(id, value_) - } - - // Called by JVM when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => value_ = x.asInstanceOf[T] - case None => { - logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - } -} - -private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) - - def stop() { HttpBroadcast.stop() } -} - -private object HttpBroadcast extends Logging { - private var initialized = false - - private var broadcastDir: File = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - private var serverUri: String = null - private var server: HttpServer = null - - private val files = new TimeStampedHashSet[String] - private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) - - private lazy val compressionCodec = CompressionCodec.createCodec() - - def initialize(isDriver: Boolean) { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - compress = System.getProperty("spark.broadcast.compress", "true").toBoolean - if (isDriver) { - createServer() - } - serverUri = System.getProperty("spark.httpBroadcast.uri") - initialized = true - } - } - } - - def stop() { - synchronized { - if (server != null) { - server.stop() - server = null - } - initialized = false - cleaner.cancel() - } - } - - private def createServer() { - broadcastDir = Utils.createTempDir(Utils.getLocalDir) - server = new HttpServer(broadcastDir) - server.start() - serverUri = server.uri - System.setProperty("spark.httpBroadcast.uri", serverUri) - logInfo("Broadcast server started at " + serverUri) - } - - def write(id: Long, value: Any) { - val file = new File(broadcastDir, "broadcast-" + id) - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(new FileOutputStream(file)) - } else { - new FastBufferedOutputStream(new FileOutputStream(file), bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() - files += file.getAbsolutePath - } - - def read[T](id: Long): T = { - val url = serverUri + "/broadcast-" + id - val in = { - if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) - } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj - } - - def cleanup(cleanupTime: Long) { - val iterator = files.internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val (file, time) = (entry.getKey, entry.getValue) - if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } - } - } - } -} diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 7855d44e9b..0000000000 --- a/core/src/main/scala/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,409 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import spark._ - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala deleted file mode 100644 index b17ae63b5c..0000000000 --- a/core/src/main/scala/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.util.BitSet - -import spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index ea1e9a12c1..0000000000 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,602 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, Random, UUID} - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import spark._ -import spark.storage.StorageLevel - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes - - logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. - */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) - listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool() - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala deleted file mode 100644 index a8b22fbef8..0000000000 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -private[spark] class ApplicationDescription( - val name: String, - val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */ - val memoryPerSlave: Int, - val command: Command, - val sparkHome: String, - val appUiUrl: String) - extends Serializable { - - val user = System.getProperty("user.name", "") - - override def toString: String = "ApplicationDescription(" + name + ")" -} diff --git a/core/src/main/scala/spark/deploy/Command.scala b/core/src/main/scala/spark/deploy/Command.scala deleted file mode 100644 index bad629e965..0000000000 --- a/core/src/main/scala/spark/deploy/Command.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import scala.collection.Map - -private[spark] case class Command( - mainClass: String, - arguments: Seq[String], - environment: Map[String, String]) { -} diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala deleted file mode 100644 index 0db13ffc98..0000000000 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import scala.collection.immutable.List - -import spark.Utils -import spark.deploy.ExecutorState.ExecutorState -import spark.deploy.master.{WorkerInfo, ApplicationInfo} -import spark.deploy.worker.ExecutorRunner - - -private[deploy] sealed trait DeployMessage extends Serializable - -private[deploy] object DeployMessages { - - // Worker to Master - - case class RegisterWorker( - id: String, - host: String, - port: Int, - cores: Int, - memory: Int, - webUiPort: Int, - publicAddress: String) - extends DeployMessage { - Utils.checkHost(host, "Required hostname") - assert (port > 0) - } - - case class ExecutorStateChanged( - appId: String, - execId: Int, - state: ExecutorState, - message: Option[String], - exitStatus: Option[Int]) - extends DeployMessage - - case class Heartbeat(workerId: String) extends DeployMessage - - // Master to Worker - - case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage - - case class RegisterWorkerFailed(message: String) extends DeployMessage - - case class KillExecutor(appId: String, execId: Int) extends DeployMessage - - case class LaunchExecutor( - appId: String, - execId: Int, - appDesc: ApplicationDescription, - cores: Int, - memory: Int, - sparkHome: String) - extends DeployMessage - - // Client to Master - - case class RegisterApplication(appDescription: ApplicationDescription) - extends DeployMessage - - // Master to Client - - case class RegisteredApplication(appId: String) extends DeployMessage - - // TODO(matei): replace hostPort with host - case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") - } - - case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], - exitStatus: Option[Int]) - - case class ApplicationRemoved(message: String) - - // Internal message in Client - - case object StopClient - - // MasterWebUI To Master - - case object RequestMasterState - - // Master to MasterWebUI - - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { - - Utils.checkHost(host, "Required hostname") - assert (port > 0) - - def uri = "spark://" + host + ":" + port - } - - // WorkerWebUI to Worker - - case object RequestWorkerState - - // Worker to WorkerWebUI - - case class WorkerStateResponse(host: String, port: Int, workerId: String, - executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, - cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - - Utils.checkHost(host, "Required hostname") - assert (port > 0) - } - - // Actor System to Master - - case object CheckForWorkerTimeOut - -} diff --git a/core/src/main/scala/spark/deploy/ExecutorState.scala b/core/src/main/scala/spark/deploy/ExecutorState.scala deleted file mode 100644 index 08c9a3b725..0000000000 --- a/core/src/main/scala/spark/deploy/ExecutorState.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -private[spark] object ExecutorState - extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { - - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value - - type ExecutorState = Value - - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state) -} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala deleted file mode 100644 index f8dcf025b4..0000000000 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import net.liftweb.json.JsonDSL._ - -import spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import spark.deploy.master.{ApplicationInfo, WorkerInfo} -import spark.deploy.worker.ExecutorRunner - - -private[spark] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { - ("id" -> obj.id) ~ - ("host" -> obj.host) ~ - ("port" -> obj.port) ~ - ("webuiaddress" -> obj.webUiAddress) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("state" -> obj.state.toString) - } - - def writeApplicationInfo(obj: ApplicationInfo) = { - ("starttime" -> obj.startTime) ~ - ("id" -> obj.id) ~ - ("name" -> obj.desc.name) ~ - ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ - ("submitdate" -> obj.submitDate.toString) - } - - def writeApplicationDescription(obj: ApplicationDescription) = { - ("name" -> obj.name) ~ - ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ - ("user" -> obj.user) - } - - def writeExecutorRunner(obj: ExecutorRunner) = { - ("id" -> obj.execId) ~ - ("memory" -> obj.memory) ~ - ("appid" -> obj.appId) ~ - ("appdesc" -> writeApplicationDescription(obj.appDesc)) - } - - def writeMasterState(obj: MasterStateResponse) = { - ("url" -> ("spark://" + obj.uri)) ~ - ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ - ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ - ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) - } - - def writeWorkerState(obj: WorkerStateResponse) = { - ("id" -> obj.workerId) ~ - ("masterurl" -> obj.masterUrl) ~ - ("masterwebuiurl" -> obj.masterWebUiUrl) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ - ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) - } -} diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala deleted file mode 100644 index 6b8e9f27af..0000000000 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} - -import spark.deploy.worker.Worker -import spark.deploy.master.Master -import spark.util.AkkaUtils -import spark.{Logging, Utils} - -import scala.collection.mutable.ArrayBuffer - -/** - * Testing class that creates a Spark standalone process in-cluster (that is, running the - * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched - * by the Workers still run in separate JVMs. This can be used to test distributed operation and - * fault recovery without spinning up a lot of processes. - */ -private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - - private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() - - def start(): String = { - logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") - - /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) - masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort - - /* Start the Workers */ - for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masterUrl, null, Some(workerNum)) - workerActorSystems += workerSystem - } - - return masterUrl - } - - def stop() { - logInfo("Shutting down local Spark cluster.") - // Stop the workers before the master so they don't get upset that it disconnected - workerActorSystems.foreach(_.shutdown()) - workerActorSystems.foreach(_.awaitTermination()) - - masterActorSystems.foreach(_.shutdown()) - masterActorSystems.foreach(_.awaitTermination()) - } -} diff --git a/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala deleted file mode 100644 index 882161e669..0000000000 --- a/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.JobConf - - -/** - * Contains util methods to interact with Hadoop from spark. - */ -class SparkHadoopUtil { - - // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems - def newConfiguration(): Configuration = new Configuration() - - // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster - def addCredentials(conf: JobConf) {} - - def isYarnMode(): Boolean = { false } - -} diff --git a/core/src/main/scala/spark/deploy/WebUI.scala b/core/src/main/scala/spark/deploy/WebUI.scala deleted file mode 100644 index 8ea7792ef4..0000000000 --- a/core/src/main/scala/spark/deploy/WebUI.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import java.text.SimpleDateFormat -import java.util.Date - -/** - * Utilities used throughout the web UI. - */ -private[spark] object DeployWebUI { - val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - - def formatDate(date: Date): String = DATE_FORMAT.format(date) - - def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp)) - - def formatDuration(milliseconds: Long): String = { - val seconds = milliseconds.toDouble / 1000 - if (seconds < 60) { - return "%.0f s".format(seconds) - } - val minutes = seconds / 60 - if (minutes < 10) { - return "%.1f min".format(minutes) - } else if (minutes < 60) { - return "%.0f min".format(minutes) - } - val hours = minutes / 60 - return "%.1f h".format(hours) - } -} diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala deleted file mode 100644 index 9d5ba8a796..0000000000 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -import java.util.concurrent.TimeoutException - -import akka.actor._ -import akka.actor.Terminated -import akka.pattern.ask -import akka.util.Duration -import akka.remote.RemoteClientDisconnected -import akka.remote.RemoteClientLifeCycleEvent -import akka.remote.RemoteClientShutdown -import akka.dispatch.Await - -import spark.Logging -import spark.deploy.{ApplicationDescription, ExecutorState} -import spark.deploy.DeployMessages._ -import spark.deploy.master.Master - - -/** - * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, - * and a listener for cluster events, and calls back the listener when various events occur. - */ -private[spark] class Client( - actorSystem: ActorSystem, - masterUrl: String, - appDescription: ApplicationDescription, - listener: ClientListener) - extends Logging { - - var actor: ActorRef = null - var appId: String = null - - class ClientActor extends Actor with Logging { - var master: ActorRef = null - var masterAddress: Address = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - - override def preStart() { - logInfo("Connecting to master " + masterUrl) - try { - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - masterAddress = master.path.address - master ! RegisterApplication(appDescription) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to master", e) - markDisconnected() - context.stop(self) - } - } - - override def receive = { - case RegisteredApplication(appId_) => - appId = appId_ - listener.connected(appId) - - case ApplicationRemoved(message) => - logError("Master removed our application: %s; stopping client".format(message)) - markDisconnected() - context.stop(self) - - case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => - val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - listener.executorAdded(fullId, workerId, hostPort, cores, memory) - - case ExecutorUpdated(id, state, message, exitStatus) => - val fullId = appId + "/" + id - val messageText = message.map(s => " (" + s + ")").getOrElse("") - logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) - if (ExecutorState.isFinished(state)) { - listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) - } - - case Terminated(actor_) if actor_ == master => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case RemoteClientDisconnected(transport, address) if address == masterAddress => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case RemoteClientShutdown(transport, address) if address == masterAddress => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case StopClient => - markDisconnected() - sender ! true - context.stop(self) - } - - /** - * Notify the listener that we disconnected, if we hadn't already done so before. - */ - def markDisconnected() { - if (!alreadyDisconnected) { - listener.disconnected() - alreadyDisconnected = true - } - } - } - - def start() { - // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) - } - - def stop() { - if (actor != null) { - try { - val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - val future = actor.ask(StopClient)(timeout) - Await.result(future, timeout) - } catch { - case e: TimeoutException => - logInfo("Stop request to Master timed out; it may already be shut down.") - } - actor = null - } - } -} diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala deleted file mode 100644 index 064024455e..0000000000 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -/** - * Callbacks invoked by deploy client when various events happen. There are currently four events: - * connecting to the cluster, disconnecting, being given an executor, and having an executor - * removed (either due to failure or due to revocation). - * - * Users of this API should *not* block inside the callback methods. - */ -private[spark] trait ClientListener { - def connected(appId: String): Unit - - def disconnected(): Unit - - def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit - - def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit -} diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala deleted file mode 100644 index 4f4daa141a..0000000000 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -import spark.util.AkkaUtils -import spark.{Logging, Utils} -import spark.deploy.{Command, ApplicationDescription} - -private[spark] object TestClient { - - class TestListener extends ClientListener with Logging { - def connected(id: String) { - logInfo("Connected to master, got app ID " + id) - } - - def disconnected() { - logInfo("Disconnected from master") - System.exit(0) - } - - def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} - - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} - } - - def main(args: Array[String]) { - val url = args(0) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) - val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") - val listener = new TestListener - val client = new Client(actorSystem, url, desc, listener) - client.start() - actorSystem.awaitTermination() - } -} diff --git a/core/src/main/scala/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/spark/deploy/client/TestExecutor.scala deleted file mode 100644 index 8a22b6b89f..0000000000 --- a/core/src/main/scala/spark/deploy/client/TestExecutor.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -private[spark] object TestExecutor { - def main(args: Array[String]) { - println("Hello world!") - while (true) { - Thread.sleep(1000) - } - } -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala deleted file mode 100644 index 6dd2f06126..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import spark.deploy.ApplicationDescription -import java.util.Date -import akka.actor.ActorRef -import scala.collection.mutable - -private[spark] class ApplicationInfo( - val startTime: Long, - val id: String, - val desc: ApplicationDescription, - val submitDate: Date, - val driver: ActorRef, - val appUiUrl: String) -{ - var state = ApplicationState.WAITING - var executors = new mutable.HashMap[Int, ExecutorInfo] - var coresGranted = 0 - var endTime = -1L - val appSource = new ApplicationSource(this) - - private var nextExecutorId = 0 - - def newExecutorId(): Int = { - val id = nextExecutorId - nextExecutorId += 1 - id - } - - def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = { - val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave) - executors(exec.id) = exec - coresGranted += cores - exec - } - - def removeExecutor(exec: ExecutorInfo) { - if (executors.contains(exec.id)) { - executors -= exec.id - coresGranted -= exec.cores - } - } - - def coresLeft: Int = desc.maxCores - coresGranted - - private var _retryCount = 0 - - def retryCount = _retryCount - - def incrementRetryCount = { - _retryCount += 1 - _retryCount - } - - def markFinished(endState: ApplicationState.Value) { - state = endState - endTime = System.currentTimeMillis() - } - - def duration: Long = { - if (endTime != -1) { - endTime - startTime - } else { - System.currentTimeMillis() - startTime - } - } - -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/spark/deploy/master/ApplicationSource.scala deleted file mode 100644 index 4df2b6bfdd..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationSource.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.deploy.master - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, - System.currentTimeMillis()) - - metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { - override def getValue: String = application.state.toString - }) - - metricRegistry.register(MetricRegistry.name("runtime_ms"), new Gauge[Long] { - override def getValue: Long = application.duration - }) - - metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] { - override def getValue: Int = application.coresGranted - }) - -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/spark/deploy/master/ApplicationState.scala deleted file mode 100644 index 94f0ad8bae..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationState.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -private[spark] object ApplicationState - extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { - - type ApplicationState = Value - - val WAITING, RUNNING, FINISHED, FAILED = Value - - val MAX_NUM_RETRY = 10 -} diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala deleted file mode 100644 index 99b60f7d09..0000000000 --- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import spark.deploy.ExecutorState - -private[spark] class ExecutorInfo( - val id: Int, - val application: ApplicationInfo, - val worker: WorkerInfo, - val cores: Int, - val memory: Int) { - - var state = ExecutorState.LAUNCHING - - def fullId: String = application.id + "/" + id -} diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala deleted file mode 100644 index 04af5e149c..0000000000 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ -import akka.actor.Terminated -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} -import akka.util.duration._ - -import spark.{Logging, SparkException, Utils} -import spark.deploy.{ApplicationDescription, ExecutorState} -import spark.deploy.DeployMessages._ -import spark.deploy.master.ui.MasterWebUI -import spark.metrics.MetricsSystem -import spark.util.AkkaUtils - - -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { - val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt - - var nextAppNumber = 0 - val workers = new HashSet[WorkerInfo] - val idToWorker = new HashMap[String, WorkerInfo] - val actorToWorker = new HashMap[ActorRef, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] - - val apps = new HashSet[ApplicationInfo] - val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] - - val waitingApps = new ArrayBuffer[ApplicationInfo] - val completedApps = new ArrayBuffer[ApplicationInfo] - - var firstApp: Option[ApplicationInfo] = None - - Utils.checkHost(host, "Expected hostname") - - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master") - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications") - val masterSource = new MasterSource(this) - - val webUi = new MasterWebUI(this, webUiPort) - - val masterPublicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host - } - - // As a temporary workaround before better ways of configuring memory, we allow users to set - // a flag that will perform round-robin scheduling across the nodes (spreading out each app - // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean - - override def preStart() { - logInfo("Starting Spark master at spark://" + host + ":" + port) - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - webUi.start() - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) - - masterMetricsSystem.registerSource(masterSource) - masterMetricsSystem.start() - applicationMetricsSystem.start() - } - - override def postStop() { - webUi.stop() - masterMetricsSystem.stop() - applicationMetricsSystem.stop() - } - - override def receive = { - case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => { - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - host, workerPort, cores, Utils.megabytesToString(memory))) - if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") - } else { - addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) - context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get) - schedule() - } - } - - case RegisterApplication(description) => { - logInfo("Registering app " + description.name) - val app = addApplication(description, sender) - logInfo("Registered app " + description.name + " with ID " + app.id) - waitingApps += app - context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredApplication(app.id) - schedule() - } - - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { - val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) - execOption match { - case Some(exec) => { - exec.state = state - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) - if (ExecutorState.isFinished(state)) { - val appInfo = idToApp(appId) - // Remove this executor from the worker and app - logInfo("Removing executor " + exec.fullId + " because it is " + state) - appInfo.removeExecutor(exec) - exec.worker.removeExecutor(exec) - - // Only retry certain number of times so we don't go into an infinite loop. - if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { - schedule() - } else { - logError("Application %s with ID %s failed %d times, removing it".format( - appInfo.desc.name, appInfo.id, appInfo.retryCount)) - removeApplication(appInfo, ApplicationState.FAILED) - } - } - } - case None => - logWarning("Got status update for unknown executor " + appId + "/" + execId) - } - } - - case Heartbeat(workerId) => { - idToWorker.get(workerId) match { - case Some(workerInfo) => - workerInfo.lastHeartbeat = System.currentTimeMillis() - case None => - logWarning("Got heartbeat from unregistered worker " + workerId) - } - } - - case Terminated(actor) => { - // The disconnected actor could've been either a worker or an app; remove whichever of - // those we have an entry for in the corresponding actor hashmap - actorToWorker.get(actor).foreach(removeWorker) - actorToApp.get(actor).foreach(finishApplication) - } - - case RemoteClientDisconnected(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - } - - case RemoteClientShutdown(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - } - - case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray) - } - - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() - } - } - - /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. This method will be called - * every time a new app joins or resource availability changes. - */ - def schedule() { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. - if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 - } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome) - app.state = ApplicationState.RUNNING - } - } - } - } else { - // Pack each app into as few nodes as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec, app.desc.sparkHome) - app.state = ApplicationState.RUNNING - } - } - } - } - } - } - - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { - logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) - worker.addExecutor(exec) - worker.actor ! LaunchExecutor( - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) - } - - def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, - publicAddress: String): WorkerInfo = { - // There may be one or more refs to dead workers on this same node (w/ different ID's), - // remove them. - workers.filter { w => - (w.host == host && w.port == port) && (w.state == WorkerState.DEAD) - }.foreach { w => - workers -= w - } - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) - workers += worker - idToWorker(worker.id) = worker - actorToWorker(sender) = worker - addressToWorker(sender.path.address) = worker - worker - } - - def removeWorker(worker: WorkerInfo) { - logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) - worker.setState(WorkerState.DEAD) - idToWorker -= worker.id - actorToWorker -= worker.actor - addressToWorker -= worker.actor.path.address - for (exec <- worker.executors.values) { - logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) - exec.application.removeExecutor(exec) - } - } - - def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { - val now = System.currentTimeMillis() - val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) - applicationMetricsSystem.registerSource(app.appSource) - apps += app - idToApp(app.id) = app - actorToApp(driver) = app - addressToApp(driver.path.address) = app - if (firstApp == None) { - firstApp = Some(app) - } - val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray - if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) { - logWarning("Could not find any workers with enough memory for " + firstApp.get.id) - } - app - } - - def finishApplication(app: ApplicationInfo) { - removeApplication(app, ApplicationState.FINISHED) - } - - def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) { - if (apps.contains(app)) { - logInfo("Removing app " + app.id) - apps -= app - idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address - if (completedApps.size >= RETAINED_APPLICATIONS) { - val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) - completedApps.take(toRemove).foreach( a => { - applicationMetricsSystem.removeSource(a.appSource) - }) - completedApps.trimStart(toRemove) - } - completedApps += app // Remember it in our history - waitingApps -= app - for (exec <- app.executors.values) { - exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(exec.application.id, exec.id) - exec.state = ExecutorState.KILLED - } - app.markFinished(state) - if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) - } - schedule() - } - } - - /** Generate a new app ID given a app's submission date */ - def newApplicationId(submitDate: Date): String = { - val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber) - nextAppNumber += 1 - appId - } - - /** Check for, and remove, any timed-out workers */ - def timeOutDeadWorkers() { - // Copy the workers into an array so we don't modify the hashset while iterating through it - val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray - for (worker <- toRemove) { - if (worker.state != WorkerState.DEAD) { - logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) - removeWorker(worker) - } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) - workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it - } - } - } -} - -private[spark] object Master { - private val systemName = "sparkMaster" - private val actorName = "Master" - private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r - - def main(argStrings: Array[String]) { - val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) - actorSystem.awaitTermination() - } - - /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ - def toAkkaUrl(sparkUrl: String): String = { - sparkUrl match { - case sparkUrlRegex(host, port) => - "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) - case _ => - throw new SparkException("Invalid master URL: " + sparkUrl) - } - } - - def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) - (actorSystem, boundPort) - } -} diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala deleted file mode 100644 index 0ae0160767..0000000000 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import spark.util.IntParam -import spark.Utils - -/** - * Command-line parser for the master. - */ -private[spark] class MasterArguments(args: Array[String]) { - var host = Utils.localHostName() - var port = 7077 - var webUiPort = 8080 - - // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_HOST") != null) { - host = System.getenv("SPARK_MASTER_HOST") - } - if (System.getenv("SPARK_MASTER_PORT") != null) { - port = System.getenv("SPARK_MASTER_PORT").toInt - } - if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { - webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt - } - if (System.getProperty("master.ui.port") != null) { - webUiPort = System.getProperty("master.ui.port").toInt - } - - parse(args.toList) - - def parse(args: List[String]): Unit = args match { - case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) - host = value - parse(tail) - - case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) - host = value - parse(tail) - - case ("--port" | "-p") :: IntParam(value) :: tail => - port = value - parse(tail) - - case "--webui-port" :: IntParam(value) :: tail => - webUiPort = value - parse(tail) - - case ("--help" | "-h") :: tail => - printUsageAndExit(0) - - case Nil => {} - - case _ => - printUsageAndExit(1) - } - - /** - * Print usage and exit JVM with the given exit code. - */ - def printUsageAndExit(exitCode: Int) { - System.err.println( - "Usage: Master [options]\n" + - "\n" + - "Options:\n" + - " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + - " -h HOST, --host HOST Hostname to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") - System.exit(exitCode) - } -} diff --git a/core/src/main/scala/spark/deploy/master/MasterSource.scala b/core/src/main/scala/spark/deploy/master/MasterSource.scala deleted file mode 100644 index b8cfa6a773..0000000000 --- a/core/src/main/scala/spark/deploy/master/MasterSource.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.deploy.master - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" - - // Gauge for worker numbers in cluster - metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] { - override def getValue: Int = master.workers.size - }) - - // Gauge for application numbers in cluster - metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] { - override def getValue: Int = master.apps.size - }) - - // Gauge for waiting application numbers in cluster - metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] { - override def getValue: Int = master.waitingApps.size - }) -} diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala deleted file mode 100644 index 4135cfeb28..0000000000 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import akka.actor.ActorRef -import scala.collection.mutable -import spark.Utils - -private[spark] class WorkerInfo( - val id: String, - val host: String, - val port: Int, - val cores: Int, - val memory: Int, - val actor: ActorRef, - val webUiPort: Int, - val publicAddress: String) { - - Utils.checkHost(host, "Expected hostname") - assert (port > 0) - - var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info - var state: WorkerState.Value = WorkerState.ALIVE - var coresUsed = 0 - var memoryUsed = 0 - - var lastHeartbeat = System.currentTimeMillis() - - def coresFree: Int = cores - coresUsed - def memoryFree: Int = memory - memoryUsed - - def hostPort: String = { - assert (port > 0) - host + ":" + port - } - - def addExecutor(exec: ExecutorInfo) { - executors(exec.fullId) = exec - coresUsed += exec.cores - memoryUsed += exec.memory - } - - def removeExecutor(exec: ExecutorInfo) { - if (executors.contains(exec.fullId)) { - executors -= exec.fullId - coresUsed -= exec.cores - memoryUsed -= exec.memory - } - } - - def hasExecutor(app: ApplicationInfo): Boolean = { - executors.values.exists(_.application == app) - } - - def webUiAddress : String = { - "http://" + this.publicAddress + ":" + this.webUiPort - } - - def setState(state: WorkerState.Value) = { - this.state = state - } -} diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala deleted file mode 100644 index 3e50b7748d..0000000000 --- a/core/src/main/scala/spark/deploy/master/WorkerState.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") { - type WorkerState = Value - - val ALIVE, DEAD, DECOMMISSIONED = Value -} diff --git a/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala deleted file mode 100644 index 2ad98f759c..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import javax.servlet.http.HttpServletRequest - -import net.liftweb.json.JsonAST.JValue - -import spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import spark.deploy.JsonProtocol -import spark.deploy.master.ExecutorInfo -import spark.ui.UIUtils -import spark.Utils - -private[spark] class ApplicationPage(parent: MasterWebUI) { - val master = parent.masterActorRef - implicit val timeout = parent.timeout - - /** Executor details for a particular application */ - def renderJson(request: HttpServletRequest): JValue = { - val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) - JsonProtocol.writeApplicationInfo(app) - } - - /** Executor details for a particular application */ - def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) - - val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") - val executors = app.executors.values.toSeq - val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors) - - val content = -
-
-
    -
  • ID: {app.id}
  • -
  • Name: {app.desc.name}
  • -
  • User: {app.desc.user}
  • -
  • Cores: - { - if (app.desc.maxCores == Integer.MAX_VALUE) { - "Unlimited (%s granted)".format(app.coresGranted) - } else { - "%s (%s granted, %s left)".format( - app.desc.maxCores, app.coresGranted, app.coresLeft) - } - } -
  • -
  • - Executor Memory: - {Utils.megabytesToString(app.desc.memoryPerSlave)} -
  • -
  • Submit Date: {app.submitDate}
  • -
  • State: {app.state}
  • -
  • Application Detail UI
  • -
-
-
- -
-
-

Executor Summary

- {executorTable} -
-
; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) - } - - def executorRow(executor: ExecutorInfo): Seq[Node] = { - - {executor.id} - - {executor.worker.id} - - {executor.cores} - {executor.memory} - {executor.state} - - stdout - stderr - - - } -} diff --git a/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala deleted file mode 100644 index 093e523e23..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import net.liftweb.json.JsonAST.JValue - -import spark.Utils -import spark.deploy.DeployWebUI -import spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import spark.deploy.JsonProtocol -import spark.deploy.master.{ApplicationInfo, WorkerInfo} -import spark.ui.UIUtils - -private[spark] class IndexPage(parent: MasterWebUI) { - val master = parent.masterActorRef - implicit val timeout = parent.timeout - - def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - JsonProtocol.writeMasterState(state) - } - - /** Index view listing applications and executors */ - def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") - val workers = state.workers.sortBy(_.id) - val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") - val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) - val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) - - val content = -
-
-
    -
  • URL: {state.uri}
  • -
  • Workers: {state.workers.size}
  • -
  • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
  • -
  • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
  • -
  • Applications: - {state.activeApps.size} Running, - {state.completedApps.size} Completed
  • -
-
-
- -
-
-

Workers

- {workerTable} -
-
- -
-
-

Running Applications

- - {activeAppsTable} -
-
- -
-
-

Completed Applications

- {completedAppsTable} -
-
; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) - } - - def workerRow(worker: WorkerInfo): Seq[Node] = { - - - {worker.id} - - {worker.host}:{worker.port} - {worker.state} - {worker.cores} ({worker.coresUsed} Used) - - {Utils.megabytesToString(worker.memory)} - ({Utils.megabytesToString(worker.memoryUsed)} Used) - - - } - - - def appRow(app: ApplicationInfo): Seq[Node] = { - - - {app.id} - - - {app.desc.name} - - - {app.coresGranted} - - - {Utils.megabytesToString(app.desc.memoryPerSlave)} - - {DeployWebUI.formatDate(app.submitDate)} - {app.desc.user} - {app.state.toString} - {DeployWebUI.formatDuration(app.duration)} - - } -} diff --git a/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala deleted file mode 100644 index c91e1db9f2..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import akka.util.Duration - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.{Logging, Utils} -import spark.deploy.master.Master -import spark.ui.JettyUtils -import spark.ui.JettyUtils._ - -/** - * Web UI server for the standalone master. - */ -private[spark] -class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { - implicit val timeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - val host = Utils.localHostName() - val port = requestedPort - - val masterActorRef = master.self - - var server: Option[Server] = None - var boundPort: Option[Int] = None - - val applicationPage = new ApplicationPage(this) - val indexPage = new IndexPage(this) - - def start() { - try { - val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) - server = Some(srv) - boundPort = Some(bPort) - logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) - } catch { - case e: Exception => - logError("Failed to create Master JettyUtils", e) - System.exit(1) - } - } - - val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ - master.applicationMetricsSystem.getServletHandlers - - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) - ) - - def stop() { - server.foreach(_.stop()) - } -} - -private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "spark/ui/static" -} diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala deleted file mode 100644 index 34665ce451..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import java.io._ -import java.lang.System.getenv - -import akka.actor.ActorRef - -import com.google.common.base.Charsets -import com.google.common.io.Files - -import spark.{Utils, Logging} -import spark.deploy.{ExecutorState, ApplicationDescription} -import spark.deploy.DeployMessages.ExecutorStateChanged - -/** - * Manages the execution of one executor process. - */ -private[spark] class ExecutorRunner( - val appId: String, - val execId: Int, - val appDesc: ApplicationDescription, - val cores: Int, - val memory: Int, - val worker: ActorRef, - val workerId: String, - val host: String, - val sparkHome: File, - val workDir: File) - extends Logging { - - val fullId = appId + "/" + execId - var workerThread: Thread = null - var process: Process = null - var shutdownHook: Thread = null - - private def getAppEnv(key: String): Option[String] = - appDesc.command.environment.get(key).orElse(Option(getenv(key))) - - def start() { - workerThread = new Thread("ExecutorRunner for " + fullId) { - override def run() { fetchAndRunExecutor() } - } - workerThread.start() - - // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - if (process != null) { - logInfo("Shutdown hook killing child process.") - process.destroy() - process.waitFor() - } - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) - } - - /** Stop this executor runner, including killing the process it launched */ - def kill() { - if (workerThread != null) { - workerThread.interrupt() - workerThread = null - if (process != null) { - logInfo("Killing process!") - process.destroy() - process.waitFor() - } - worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None) - Runtime.getRuntime.removeShutdownHook(shutdownHook) - } - } - - /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ - def substituteVariables(argument: String): String = argument match { - case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => host - case "{{CORES}}" => cores.toString - case other => other - } - - def buildCommandSeq(): Seq[String] = { - val command = appDesc.command - val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() - // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - command.arguments.map(substituteVariables) - } - - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. - */ - def buildJavaOpts(): Seq[String] = { - val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH") - .map(p => List("-Djava.library.path=" + p)) - .getOrElse(Nil) - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) - val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil) - val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=appDesc.command.environment) - - Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts - } - - /** Spawn a thread that will redirect a given stream to a file */ - def redirectStream(in: InputStream, file: File) { - val out = new FileOutputStream(file, true) - new Thread("redirect output to " + file) { - override def run() { - try { - Utils.copyStream(in, out, true) - } catch { - case e: IOException => - logInfo("Redirection to " + file + " closed: " + e.getMessage) - } - } - }.start() - } - - /** - * Download and run the executor described in our ApplicationDescription - */ - def fetchAndRunExecutor() { - try { - // Create the executor's working directory - val executorDir = new File(workDir, appId + "/" + execId) - if (!executorDir.mkdirs()) { - throw new IOException("Failed to create directory " + executorDir) - } - - // Launch the process - val command = buildCommandSeq() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) - val builder = new ProcessBuilder(command: _*).directory(executorDir) - val env = builder.environment() - for ((key, value) <- appDesc.command.environment) { - env.put(key, value) - } - // In case we are running this from within the Spark Shell, avoid creating a "scala" - // parent process for the executor command - env.put("SPARK_LAUNCH_WITH_SCALA", "0") - process = builder.start() - - val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) - - // Redirect its stdout and stderr to files - val stdout = new File(executorDir, "stdout") - redirectStream(process.getInputStream, stdout) - - val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, Charsets.UTF_8) - redirectStream(process.getErrorStream, stderr) - - // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run - // long-lived processes only. However, in the future, we might restart the executor a few - // times on the same machine. - val exitCode = process.waitFor() - val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), - Some(exitCode)) - } catch { - case interrupted: InterruptedException => - logInfo("Runner thread for executor " + fullId + " interrupted") - - case e: Exception => { - logError("Error running executor", e) - if (process != null) { - process.destroy() - } - val message = e.getClass + ": " + e.getMessage - worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None) - } - } - } -} diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala deleted file mode 100644 index 053ac55226..0000000000 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import java.text.SimpleDateFormat -import java.util.Date -import java.io.File - -import scala.collection.mutable.HashMap - -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import akka.util.duration._ - -import spark.{Logging, Utils} -import spark.deploy.ExecutorState -import spark.deploy.DeployMessages._ -import spark.deploy.master.Master -import spark.deploy.worker.ui.WorkerWebUI -import spark.metrics.MetricsSystem -import spark.util.AkkaUtils - - -private[spark] class Worker( - host: String, - port: Int, - webUiPort: Int, - cores: Int, - memory: Int, - masterUrl: String, - workDirPath: String = null) - extends Actor with Logging { - - Utils.checkHost(host, "Expected hostname") - assert (port > 0) - - val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs - - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 - - var master: ActorRef = null - var masterWebUiUrl : String = "" - val workerId = generateWorkerId() - var sparkHome: File = null - var workDir: File = null - val executors = new HashMap[String, ExecutorRunner] - val finishedExecutors = new HashMap[String, ExecutorRunner] - val publicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host - } - var webUi: WorkerWebUI = null - - var coresUsed = 0 - var memoryUsed = 0 - - val metricsSystem = MetricsSystem.createMetricsSystem("worker") - val workerSource = new WorkerSource(this) - - def coresFree: Int = cores - coresUsed - def memoryFree: Int = memory - memoryUsed - - def createWorkDir() { - workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) - try { - // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() - // So attempting to create and then check if directory was created or not. - workDir.mkdirs() - if ( !workDir.exists() || !workDir.isDirectory) { - logError("Failed to create work directory " + workDir) - System.exit(1) - } - assert (workDir.isDirectory) - } catch { - case e: Exception => - logError("Failed to create work directory " + workDir, e) - System.exit(1) - } - } - - override def preStart() { - logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - host, port, cores, Utils.megabytesToString(memory))) - sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) - logInfo("Spark home: " + sparkHome) - createWorkDir() - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) - - webUi.start() - connectToMaster() - - metricsSystem.registerSource(workerSource) - metricsSystem.start() - } - - def connectToMaster() { - logInfo("Connecting to master " + masterUrl) - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } - - override def receive = { - case RegisteredWorker(url) => - masterWebUiUrl = url - logInfo("Successfully registered with master") - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) { - master ! Heartbeat(workerId) - } - - case RegisterWorkerFailed(message) => - logError("Worker registration failed: " + message) - System.exit(1) - - case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => - logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) - val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir) - executors(appId + "/" + execId) = manager - manager.start() - coresUsed += cores_ - memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None) - - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - val executor = executors(fullId) - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - finishedExecutors(fullId) = executor - executors -= fullId - coresUsed -= executor.cores - memoryUsed -= executor.memory - } - - case KillExecutor(appId, execId) => - val fullId = appId + "/" + execId - executors.get(fullId) match { - case Some(executor) => - logInfo("Asked to kill executor " + fullId) - executor.kill() - case None => - logInfo("Asked to kill unknown executor " + fullId) - } - - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - masterDisconnected() - - case RequestWorkerState => { - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, masterUrl, cores, memory, - coresUsed, memoryUsed, masterWebUiUrl) - } - } - - def masterDisconnected() { - // TODO: It would be nice to try to reconnect to the master, but just shut down for now. - // (Note that if reconnecting we would also need to assign IDs differently.) - logError("Connection to master failed! Shutting down.") - executors.values.foreach(_.kill()) - System.exit(1) - } - - def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) - } - - override def postStop() { - executors.values.foreach(_.kill()) - webUi.stop() - metricsSystem.stop() - } -} - -private[spark] object Worker { - def main(argStrings: Array[String]) { - val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.master, args.workDir) - actorSystem.awaitTermination() - } - - def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { - // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, - masterUrl, workDir)), name = "Worker") - (actorSystem, boundPort) - } - -} diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala deleted file mode 100644 index 9fcd3260ca..0000000000 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import spark.util.IntParam -import spark.util.MemoryParam -import spark.Utils -import java.lang.management.ManagementFactory - -/** - * Command-line parser for the master. - */ -private[spark] class WorkerArguments(args: Array[String]) { - var host = Utils.localHostName() - var port = 0 - var webUiPort = 8081 - var cores = inferDefaultCores() - var memory = inferDefaultMemory() - var master: String = null - var workDir: String = null - - // Check for settings in environment variables - if (System.getenv("SPARK_WORKER_PORT") != null) { - port = System.getenv("SPARK_WORKER_PORT").toInt - } - if (System.getenv("SPARK_WORKER_CORES") != null) { - cores = System.getenv("SPARK_WORKER_CORES").toInt - } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) - } - if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { - webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt - } - if (System.getenv("SPARK_WORKER_DIR") != null) { - workDir = System.getenv("SPARK_WORKER_DIR") - } - - parse(args.toList) - - def parse(args: List[String]): Unit = args match { - case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) - host = value - parse(tail) - - case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) - host = value - parse(tail) - - case ("--port" | "-p") :: IntParam(value) :: tail => - port = value - parse(tail) - - case ("--cores" | "-c") :: IntParam(value) :: tail => - cores = value - parse(tail) - - case ("--memory" | "-m") :: MemoryParam(value) :: tail => - memory = value - parse(tail) - - case ("--work-dir" | "-d") :: value :: tail => - workDir = value - parse(tail) - - case "--webui-port" :: IntParam(value) :: tail => - webUiPort = value - parse(tail) - - case ("--help" | "-h") :: tail => - printUsageAndExit(0) - - case value :: tail => - if (master != null) { // Two positional arguments were given - printUsageAndExit(1) - } - master = value - parse(tail) - - case Nil => - if (master == null) { // No positional argument was given - printUsageAndExit(1) - } - - case _ => - printUsageAndExit(1) - } - - /** - * Print usage and exit JVM with the given exit code. - */ - def printUsageAndExit(exitCode: Int) { - System.err.println( - "Usage: Worker [options] \n" + - "\n" + - "Master must be a URL of the form spark://hostname:port\n" + - "\n" + - "Options:\n" + - " -c CORES, --cores CORES Number of cores to use\n" + - " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + - " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + - " -h HOST, --host HOST Hostname to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") - System.exit(exitCode) - } - - def inferDefaultCores(): Int = { - Runtime.getRuntime.availableProcessors() - } - - def inferDefaultMemory(): Int = { - val ibmVendor = System.getProperty("java.vendor").contains("IBM") - var totalMb = 0 - try { - val bean = ManagementFactory.getOperatingSystemMXBean() - if (ibmVendor) { - val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") - val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") - totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt - } else { - val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") - val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") - totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt - } - } catch { - case e: Exception => { - totalMb = 2*1024 - System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") - } - } - // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) - } -} diff --git a/core/src/main/scala/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/spark/deploy/worker/WorkerSource.scala deleted file mode 100644 index 39cb8e5690..0000000000 --- a/core/src/main/scala/spark/deploy/worker/WorkerSource.scala +++ /dev/null @@ -1,34 +0,0 @@ -package spark.deploy.worker - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() - - metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] { - override def getValue: Int = worker.executors.size - }) - - // Gauge for cores used of this worker - metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] { - override def getValue: Int = worker.coresUsed - }) - - // Gauge for memory used of this worker - metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] { - override def getValue: Int = worker.memoryUsed - }) - - // Gauge for cores free of this worker - metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] { - override def getValue: Int = worker.coresFree - }) - - // Gauge for memory free of this worker - metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] { - override def getValue: Int = worker.memoryFree - }) -} diff --git a/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala deleted file mode 100644 index 243e0765cb..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker.ui - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import net.liftweb.json.JsonAST.JValue - -import spark.Utils -import spark.deploy.JsonProtocol -import spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import spark.deploy.worker.ExecutorRunner -import spark.ui.UIUtils - - -private[spark] class IndexPage(parent: WorkerWebUI) { - val workerActor = parent.worker.self - val worker = parent.worker - val timeout = parent.timeout - - def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, 30 seconds) - JsonProtocol.writeWorkerState(workerState) - } - - def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, 30 seconds) - - val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs") - val runningExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.executors) - val finishedExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) - - val content = -
-
-
    -
  • ID: {workerState.workerId}
  • -
  • - Master URL: {workerState.masterUrl} -
  • -
  • Cores: {workerState.cores} ({workerState.coresUsed} Used)
  • -
  • Memory: {Utils.megabytesToString(workerState.memory)} - ({Utils.megabytesToString(workerState.memoryUsed)} Used)
  • -
-

Back to Master

-
-
- -
-
-

Running Executors {workerState.executors.size}

- {runningExecutorTable} -
-
- -
-
-

Finished Executors

- {finishedExecutorTable} -
-
; - - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( - workerState.host, workerState.port)) - } - - def executorRow(executor: ExecutorRunner): Seq[Node] = { - - {executor.execId} - {executor.cores} - - {Utils.megabytesToString(executor.memory)} - - -
    -
  • ID: {executor.appId}
  • -
  • Name: {executor.appDesc.name}
  • -
  • User: {executor.appDesc.user}
  • -
- - - stdout - stderr - - - } - -} diff --git a/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala deleted file mode 100644 index 0a75ad8cf4..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker.ui - -import akka.util.{Duration, Timeout} - -import java.io.{FileInputStream, File} - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.deploy.worker.Worker -import spark.{Utils, Logging} -import spark.ui.JettyUtils -import spark.ui.JettyUtils._ -import spark.ui.UIUtils - -/** - * Web UI server for the standalone worker. - */ -private[spark] -class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { - implicit val timeout = Timeout( - Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) - val host = Utils.localHostName() - val port = requestedPort.getOrElse( - System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) - - var server: Option[Server] = None - var boundPort: Option[Int] = None - - val indexPage = new IndexPage(this) - - val metricsHandlers = worker.metricsSystem.getServletHandlers - - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) - ) - - def start() { - try { - val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) - server = Some(srv) - boundPort = Some(bPort) - logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) - } catch { - case e: Exception => - logError("Failed to create Worker JettyUtils", e) - System.exit(1) - } - } - - def log(request: HttpServletRequest): String = { - val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n" - .format(startByte, endByte, logLength, appId, executorId, logType) - pre + Utils.offsetBytes(path, startByte, endByte) - } - - def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = { - val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val logText = {Utils.offsetBytes(path, startByte, endByte)} - - val linkToMaster =

Back to Master

- - val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} - - val backButton = - if (startByte > 0) { - - - - } - else { - - } - - val nextButton = - if (endByte < logLength) { - - - - } - else { - - } - - val content = - - - {linkToMaster} -
-
{backButton}
-
{range}
-
{nextButton}
-
-
-
-
{logText}
-
- - - UIUtils.basicSparkPage(content, logType + " log page for " + appId) - } - - /** Determine the byte range for a log or log page. */ - def getByteRange(path: String, offset: Option[Long], byteLength: Int) - : (Long, Long) = { - val defaultBytes = 100 * 1024 - val maxBytes = 1024 * 1024 - - val file = new File(path) - val logLength = file.length() - val getOffset = offset.getOrElse(logLength-defaultBytes) - - val startByte = - if (getOffset < 0) 0L - else if (getOffset > logLength) logLength - else getOffset - - val logPageLength = math.min(byteLength, maxBytes) - - val endByte = math.min(startByte+logPageLength, logLength) - - (startByte, endByte) - } - - def stop() { - server.foreach(_.stop()) - } -} - -private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "spark/ui/static" - val DEFAULT_PORT="8081" -} diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala deleted file mode 100644 index fa82d2b324..0000000000 --- a/core/src/main/scala/spark/executor/Executor.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.io.{File} -import java.lang.management.ManagementFactory -import java.nio.ByteBuffer -import java.util.concurrent._ - -import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap - -import spark.scheduler._ -import spark._ - - -/** - * The Mesos executor for Spark. - */ -private[spark] class Executor( - executorId: String, - slaveHostname: String, - properties: Seq[(String, String)]) - extends Logging -{ - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. - private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) - - initLogging() - - // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") - // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) - - // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) - - // Set spark.* system properties from executor arg - for ((key, value) <- properties) { - System.setProperty(key, value) - } - - // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. This will be used later when SparkEnv - // created. - if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) { - System.setProperty("spark.local.dir", getYarnLocalDirs()) - } - - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) - - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler( - new Thread.UncaughtExceptionHandler { - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) - } - } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) - } - } - } - ) - - val executorSource = new ExecutorSource(this) - - // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) - SparkEnv.set(env) - env.metricsSystem.registerSource(executorSource) - - private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") - - // Start worker thread pool - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - - def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) - } - - /** Get the Yarn approved local directories. */ - private def getYarnLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") - } - return localDirs - } - - class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) - extends Runnable { - - override def run() { - val startTime = System.currentTimeMillis() - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(replClassLoader) - val ser = SparkEnv.get.closureSerializer.newInstance() - logInfo("Running task ID " + taskId) - context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - SparkEnv.set(env) - Accumulators.clear() - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) - updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) - taskStart = System.currentTimeMillis() - val value = task.run(taskId.toInt) - val taskFinish = System.currentTimeMillis() - for (m <- task.metrics) { - m.hostname = Utils.localHostName - m.executorDeserializeTime = (taskStart - startTime).toInt - m.executorRunTime = (taskFinish - taskStart).toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer - val accumUpdates = Accumulators.values - val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) - val serializedResult = ser.serialize(result) - logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) - if (serializedResult.limit >= (akkaFrameSize - 1024)) { - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure())) - return - } - context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - logInfo("Finished task ID " + taskId) - } catch { - case ffe: FetchFailedException => { - val reason = ffe.toTaskEndReason - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - - case t: Throwable => { - val serviceTime = (System.currentTimeMillis() - taskStart).toInt - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime - m.jvmGCTime = getTotalGCTime - startGCTime - } - val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - - // TODO: Should we exit the whole executor here? On the one hand, the failed task may - // have left some weird state around depending on when the exception was thrown, but on - // the other hand, maybe we could detect that when future tasks fail and exit then. - logError("Exception in task ID " + taskId, t) - //System.exit(1) - } - } - } - } - - /** - * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes - * created by the interpreter to the search path - */ - private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader - - // For each of the jars in the jarSet, add them to the class loader. - // We assume each of the files has already been fetched. - val urls = currentJars.keySet.map { uri => - new File(uri.split("/").last).toURI.toURL - }.toArray - new ExecutorURLClassLoader(urls, loader) - } - - /** - * If the REPL is in use, add another ClassLoader that will read - * new classes defined by the REPL as the user types code - */ - private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { - val classUri = System.getProperty("spark.repl.class.uri") - if (classUri != null) { - logInfo("Using REPL class URI: " + classUri) - try { - val klass = Class.forName("spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) - } catch { - case _: ClassNotFoundException => - logError("Could not find spark.repl.ExecutorClassLoader on classpath!") - System.exit(1) - null - } - } else { - return parent - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. - */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) - } - } - } - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorBackend.scala b/core/src/main/scala/spark/executor/ExecutorBackend.scala deleted file mode 100644 index 33a6f8a824..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorBackend.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.nio.ByteBuffer -import spark.TaskState.TaskState - -/** - * A pluggable interface used by the Executor to send updates to the cluster scheduler. - */ -private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) -} diff --git a/core/src/main/scala/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/spark/executor/ExecutorExitCode.scala deleted file mode 100644 index 64b9fb88f8..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorExitCode.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -/** - * These are exit codes that executors should use to provide the master with information about - * executor failures assuming that cluster management framework can capture the exit codes (but - * perhaps not log files). The exit code constants here are chosen to be unlikely to conflict - * with "natural" exit statuses that may be caused by the JVM or user code. In particular, - * exit codes 128+ arise on some Unix-likes as a result of signals, and it appears that the - * OpenJDK JVM may use exit code 1 in some of its own "last chance" code. - */ -private[spark] -object ExecutorExitCode { - /** The default uncaught exception handler was reached. */ - val UNCAUGHT_EXCEPTION = 50 - - /** The default uncaught exception handler was called and an exception was encountered while - logging the exception. */ - val UNCAUGHT_EXCEPTION_TWICE = 51 - - /** The default uncaught exception handler was reached, and the uncaught exception was an - OutOfMemoryError. */ - val OOM = 52 - - /** DiskStore failed to create a local temporary directory after many attempts. */ - val DISK_STORE_FAILED_TO_CREATE_DIR = 53 - - def explainExitCode(exitCode: Int): String = { - exitCode match { - case UNCAUGHT_EXCEPTION => "Uncaught exception" - case UNCAUGHT_EXCEPTION_TWICE => "Uncaught exception, and logging the exception failed" - case OOM => "OutOfMemoryError" - case DISK_STORE_FAILED_TO_CREATE_DIR => - "Failed to create local directory (bad spark.local.dir?)" - case _ => - "Unknown executor exit code (" + exitCode + ")" + ( - if (exitCode > 128) - " (died from signal " + (exitCode - 128) + "?)" - else - "" - ) - } - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorSource.scala b/core/src/main/scala/spark/executor/ExecutorSource.scala deleted file mode 100644 index d491a3c0c9..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorSource.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.executor - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hdfs.DistributedFileSystem -import org.apache.hadoop.fs.LocalFileSystem - -import scala.collection.JavaConversions._ - -import spark.metrics.source.Source - -class ExecutorSource(val executor: Executor) extends Source { - private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption - - private def registerFileSystemStat[T]( - scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { - metricRegistry.register(MetricRegistry.name("filesystem", scheme, name), new Gauge[T] { - override def getValue: T = fileStats(scheme).map(f).getOrElse(defaultValue) - }) - } - - val metricRegistry = new MetricRegistry() - val sourceName = "executor" - - // Gauge for executor thread pool's actively executing task counts - metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getActiveCount() - }) - - // Gauge for executor thread pool's approximate total number of tasks that have been completed - metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] { - override def getValue: Long = executor.threadPool.getCompletedTaskCount() - }) - - // Gauge for executor thread pool's current number of threads - metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getPoolSize() - }) - - // Gauge got executor thread pool's largest number of threads that have ever simultaneously been in th pool - metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", "size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getMaximumPoolSize() - }) - - // Gauge for file system stats of this executor - for (scheme <- Array("hdfs", "file")) { - registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L) - registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L) - registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0) - registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0) - registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0) - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 09d12fb65b..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.net.{URLClassLoader, URL} - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - */ -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) { - - override def addURL(url: URL) { - super.addURL(url) - } -} diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala deleted file mode 100644 index 4961c42fad..0000000000 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import spark.TaskState.TaskState -import com.google.protobuf.ByteString -import spark.{Utils, Logging} -import spark.TaskState - -private[spark] class MesosExecutorBackend - extends MesosExecutor - with ExecutorBackend - with Logging { - - var executor: Executor = null - var driver: ExecutorDriver = null - - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build() - driver.sendStatusUpdate(MesosTaskStatus.newBuilder() - .setTaskId(mesosTaskId) - .setState(TaskState.toMesos(state)) - .setData(ByteString.copyFrom(data)) - .build()) - } - - override def registered( - driver: ExecutorDriver, - executorInfo: ExecutorInfo, - frameworkInfo: FrameworkInfo, - slaveInfo: SlaveInfo) { - logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) - this.driver = driver - val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor = new Executor( - executorInfo.getExecutorId.getValue, - slaveInfo.getHostname, - properties) - } - - override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { - val taskId = taskInfo.getTaskId.getValue.toLong - if (executor == null) { - logError("Received launchTask but executor was null") - } else { - executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer) - } - } - - override def error(d: ExecutorDriver, message: String) { - logError("Error from Mesos: " + message) - } - - override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") - } - - override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} - - override def disconnected(d: ExecutorDriver) {} - - override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {} - - override def shutdown(d: ExecutorDriver) {} -} - -/** - * Entry point for Mesos executor. - */ -private[spark] object MesosExecutorBackend { - def main(args: Array[String]) { - MesosNativeLibrary.load() - // Create a new Executor and start it running - val runner = new MesosExecutorBackend() - new MesosExecutorDriver(runner).run() - } -} diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala deleted file mode 100644 index b5fb6dbe29..0000000000 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.nio.ByteBuffer - -import akka.actor.{ActorRef, Actor, Props, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} - -import spark.{Logging, Utils, SparkEnv} -import spark.TaskState.TaskState -import spark.scheduler.cluster.StandaloneClusterMessages._ -import spark.util.AkkaUtils - - -private[spark] class StandaloneExecutorBackend( - driverUrl: String, - executorId: String, - hostPort: String, - cores: Int) - extends Actor - with ExecutorBackend - with Logging { - - Utils.checkHostPort(hostPort, "Expected hostport") - - var executor: Executor = null - var driver: ActorRef = null - - override def preStart() { - logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing - } - - override def receive = { - case RegisteredExecutor(sparkProperties) => - logInfo("Successfully registered with driver") - // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) - - case RegisterExecutorFailed(message) => - logError("Slave registration failed: " + message) - System.exit(1) - - case LaunchTask(taskDesc) => - logInfo("Got assigned task " + taskDesc.taskId) - if (executor == null) { - logError("Received launchTask but executor was null") - System.exit(1) - } else { - executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) - } - - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - logError("Driver terminated or disconnected! Shutting down.") - System.exit(1) - } - - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) - } -} - -private[spark] object StandaloneExecutorBackend { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { - // Debug code - Utils.checkHost(hostname) - - // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor - // before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) - // set it - val sparkHostPort = hostname + ":" + boundPort - System.setProperty("spark.hostPort", sparkHostPort) - val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), - name = "Executor") - actorSystem.awaitTermination() - } - - def main(args: Array[String]) { - if (args.length < 4) { - //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors - System.err.println("Usage: StandaloneExecutorBackend []") - System.exit(1) - } - run(args(0), args(1), args(2), args(3).toInt) - } -} diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala deleted file mode 100644 index 47b8890bee..0000000000 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -class TaskMetrics extends Serializable { - /** - * Host's name the task runs on - */ - var hostname: String = _ - - /** - * Time taken on the executor to deserialize this task - */ - var executorDeserializeTime: Int = _ - - /** - * Time the executor spends actually running the task (including fetching shuffle data) - */ - var executorRunTime: Int = _ - - /** - * The number of bytes this task transmitted back to the driver as the TaskResult - */ - var resultSize: Long = _ - - /** - * Amount of time the JVM spent in garbage collection while executing this task - */ - var jvmGCTime: Long = _ - - /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here - */ - var shuffleReadMetrics: Option[ShuffleReadMetrics] = None - - /** - * If this task writes to shuffle output, metrics on the written shuffle data will be collected here - */ - var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None -} - -object TaskMetrics { - private[spark] def empty(): TaskMetrics = new TaskMetrics -} - - -class ShuffleReadMetrics extends Serializable { - /** - * Time when shuffle finishs - */ - var shuffleFinishTime: Long = _ - - /** - * Total number of blocks fetched in a shuffle (remote or local) - */ - var totalBlocksFetched: Int = _ - - /** - * Number of remote blocks fetched in a shuffle - */ - var remoteBlocksFetched: Int = _ - - /** - * Local blocks fetched in a shuffle - */ - var localBlocksFetched: Int = _ - - /** - * Total time that is spent blocked waiting for shuffle to fetch data - */ - var fetchWaitTime: Long = _ - - /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time - */ - var remoteFetchTime: Long = _ - - /** - * Total number of remote bytes read from a shuffle - */ - var remoteBytesRead: Long = _ -} - -class ShuffleWriteMetrics extends Serializable { - /** - * Number of bytes written for a shuffle - */ - var shuffleBytesWritten: Long = _ -} diff --git a/core/src/main/scala/spark/io/CompressionCodec.scala b/core/src/main/scala/spark/io/CompressionCodec.scala deleted file mode 100644 index 0adebecadb..0000000000 --- a/core/src/main/scala/spark/io/CompressionCodec.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.io - -import java.io.{InputStream, OutputStream} - -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - -import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} - - -/** - * CompressionCodec allows the customization of choosing different compression implementations - * to be used in block storage. - */ -trait CompressionCodec { - - def compressedOutputStream(s: OutputStream): OutputStream - - def compressedInputStream(s: InputStream): InputStream -} - - -private[spark] object CompressionCodec { - - def createCodec(): CompressionCodec = { - // Set the default codec to Snappy since the LZF implementation initializes a pretty large - // buffer for every stream, which results in a lot of memory overhead when the number of - // shuffle reduce buckets are large. - createCodec(classOf[SnappyCompressionCodec].getName) - } - - def createCodec(codecName: String): CompressionCodec = { - Class.forName( - System.getProperty("spark.io.compression.codec", codecName), - true, - Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[CompressionCodec] - } -} - - -/** - * LZF implementation of [[spark.io.CompressionCodec]]. - */ -class LZFCompressionCodec extends CompressionCodec { - - override def compressedOutputStream(s: OutputStream): OutputStream = { - new LZFOutputStream(s).setFinishBlockOnFlush(true) - } - - override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s) -} - - -/** - * Snappy implementation of [[spark.io.CompressionCodec]]. - * Block size can be configured by spark.io.compression.snappy.block.size. - */ -class SnappyCompressionCodec extends CompressionCodec { - - override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt - new SnappyOutputStream(s, blockSize) - } - - override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) -} diff --git a/core/src/main/scala/spark/metrics/MetricsConfig.scala b/core/src/main/scala/spark/metrics/MetricsConfig.scala deleted file mode 100644 index d7fb5378a4..0000000000 --- a/core/src/main/scala/spark/metrics/MetricsConfig.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics - -import java.util.Properties -import java.io.{File, FileInputStream, InputStream, IOException} - -import scala.collection.mutable -import scala.util.matching.Regex - -import spark.Logging - -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - initLogging() - - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" - - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null - - private def setDefaultProperties(prop: Properties) { - prop.setProperty("*.sink.servlet.class", "spark.metrics.sink.MetricsServlet") - prop.setProperty("*.sink.servlet.uri", "/metrics/json") - prop.setProperty("*.sink.servlet.sample", "false") - prop.setProperty("master.sink.servlet.uri", "/metrics/master/json") - prop.setProperty("applications.sink.servlet.uri", "/metrics/applications/json") - } - - def initialize() { - //Add default properties in case there's no properties file - setDefaultProperties(properties) - - // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF) - } - - if (is != null) { - properties.load(is) - } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() - } - - propertyCategories = subProperties(properties, INSTANCE_REGEX) - if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) - } - } - } - - def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { - val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1) != None) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) - } - } - subProperties - } - - def getInstance(inst: String): Properties = { - propertyCategories.get(inst) match { - case Some(s) => s - case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) - } - } -} - diff --git a/core/src/main/scala/spark/metrics/MetricsSystem.scala b/core/src/main/scala/spark/metrics/MetricsSystem.scala deleted file mode 100644 index 4e6c6b26c8..0000000000 --- a/core/src/main/scala/spark/metrics/MetricsSystem.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics - -import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} - -import java.util.Properties -import java.util.concurrent.TimeUnit - -import scala.collection.mutable - -import spark.Logging -import spark.metrics.sink.{MetricsServlet, Sink} -import spark.metrics.source.Source - -/** - * Spark Metrics System, created by specific "instance", combined by source, - * sink, periodically poll source metrics data to sink destinations. - * - * "instance" specify "who" (the role) use metrics system. In spark there are several roles - * like master, worker, executor, client driver, these roles will create metrics system - * for monitoring. So instance represents these roles. Currently in Spark, several instances - * have already implemented: master, worker, executor, driver, applications. - * - * "source" specify "where" (source) to collect metrics data. In metrics system, there exists - * two kinds of source: - * 1. Spark internal source, like MasterSource, WorkerSource, etc, which will collect - * Spark component's internal state, these sources are related to instance and will be - * added after specific metrics system is created. - * 2. Common source, like JvmSource, which will collect low level state, is configured by - * configuration and loaded through reflection. - * - * "sink" specify "where" (destination) to output metrics data to. Several sinks can be - * coexisted and flush metrics to all these sinks. - * - * Metrics configuration format is like below: - * [instance].[sink|source].[name].[options] = xxxx - * - * [instance] can be "master", "worker", "executor", "driver", "applications" which means only - * the specified instance has this property. - * wild card "*" can be used to replace instance name, which means all the instances will have - * this property. - * - * [sink|source] means this property belongs to source or sink. This field can only be source or sink. - * - * [name] specify the name of sink or source, it is custom defined. - * - * [options] is the specific property of this source or sink. - */ -private[spark] class MetricsSystem private (val instance: String) extends Logging { - initLogging() - - val confFile = System.getProperty("spark.metrics.conf") - val metricsConfig = new MetricsConfig(Option(confFile)) - - val sinks = new mutable.ArrayBuffer[Sink] - val sources = new mutable.ArrayBuffer[Source] - val registry = new MetricRegistry() - - // Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui - private var metricsServlet: Option[MetricsServlet] = None - - /** Get any UI handlers used by this metrics system. */ - def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array()) - - metricsConfig.initialize() - registerSources() - registerSinks() - - def start() { - sinks.foreach(_.start) - } - - def stop() { - sinks.foreach(_.stop) - } - - def registerSource(source: Source) { - sources += source - try { - registry.register(source.sourceName, source.metricRegistry) - } catch { - case e: IllegalArgumentException => logInfo("Metrics already registered", e) - } - } - - def removeSource(source: Source) { - sources -= source - registry.removeMatching(new MetricFilter { - def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName) - }) - } - - def registerSources() { - val instConfig = metricsConfig.getInstance(instance) - val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX) - - // Register all the sources related to instance - sourceConfigs.foreach { kv => - val classPath = kv._2.getProperty("class") - try { - val source = Class.forName(classPath).newInstance() - registerSource(source.asInstanceOf[Source]) - } catch { - case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e) - } - } - } - - def registerSinks() { - val instConfig = metricsConfig.getInstance(instance) - val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX) - - sinkConfigs.foreach { kv => - val classPath = kv._2.getProperty("class") - try { - val sink = Class.forName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry]) - .newInstance(kv._2, registry) - if (kv._1 == "servlet") { - metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) - } else { - sinks += sink.asInstanceOf[Sink] - } - } catch { - case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) - } - } - } -} - -private[spark] object MetricsSystem { - val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r - val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r - - val MINIMAL_POLL_UNIT = TimeUnit.SECONDS - val MINIMAL_POLL_PERIOD = 1 - - def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) { - val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit) - if (period < MINIMAL_POLL_PERIOD) { - throw new IllegalArgumentException("Polling period " + pollPeriod + " " + pollUnit + - " below than minimal polling period ") - } - } - - def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance) -} diff --git a/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala deleted file mode 100644 index 966ba37c20..0000000000 --- a/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.{ConsoleReporter, MetricRegistry} - -import java.util.Properties -import java.util.concurrent.TimeUnit - -import spark.metrics.MetricsSystem - -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val CONSOLE_DEFAULT_PERIOD = 10 - val CONSOLE_DEFAULT_UNIT = "SECONDS" - - val CONSOLE_KEY_PERIOD = "period" - val CONSOLE_KEY_UNIT = "unit" - - val pollPeriod = Option(property.getProperty(CONSOLE_KEY_PERIOD)) match { - case Some(s) => s.toInt - case None => CONSOLE_DEFAULT_PERIOD - } - - val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) - case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) - } - - MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - - val reporter: ConsoleReporter = ConsoleReporter.forRegistry(registry) - .convertDurationsTo(TimeUnit.MILLISECONDS) - .convertRatesTo(TimeUnit.SECONDS) - .build() - - override def start() { - reporter.start(pollPeriod, pollUnit) - } - - override def stop() { - reporter.stop() - } -} - diff --git a/core/src/main/scala/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/spark/metrics/sink/CsvSink.scala deleted file mode 100644 index cb990afdef..0000000000 --- a/core/src/main/scala/spark/metrics/sink/CsvSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.{CsvReporter, MetricRegistry} - -import java.io.File -import java.util.{Locale, Properties} -import java.util.concurrent.TimeUnit - -import spark.metrics.MetricsSystem - -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val CSV_KEY_PERIOD = "period" - val CSV_KEY_UNIT = "unit" - val CSV_KEY_DIR = "directory" - - val CSV_DEFAULT_PERIOD = 10 - val CSV_DEFAULT_UNIT = "SECONDS" - val CSV_DEFAULT_DIR = "/tmp/" - - val pollPeriod = Option(property.getProperty(CSV_KEY_PERIOD)) match { - case Some(s) => s.toInt - case None => CSV_DEFAULT_PERIOD - } - - val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) - case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) - } - - MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - - val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match { - case Some(s) => s - case None => CSV_DEFAULT_DIR - } - - val reporter: CsvReporter = CsvReporter.forRegistry(registry) - .formatFor(Locale.US) - .convertDurationsTo(TimeUnit.MILLISECONDS) - .convertRatesTo(TimeUnit.SECONDS) - .build(new File(pollDir)) - - override def start() { - reporter.start(pollPeriod, pollUnit) - } - - override def stop() { - reporter.stop() - } -} - diff --git a/core/src/main/scala/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/spark/metrics/sink/JmxSink.scala deleted file mode 100644 index ee04544c0e..0000000000 --- a/core/src/main/scala/spark/metrics/sink/JmxSink.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.{JmxReporter, MetricRegistry} - -import java.util.Properties - -class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() - - override def start() { - reporter.start() - } - - override def stop() { - reporter.stop() - } - -} diff --git a/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala deleted file mode 100644 index 17432b1ed1..0000000000 --- a/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.json.MetricsModule - -import com.fasterxml.jackson.databind.ObjectMapper - -import java.util.Properties -import java.util.concurrent.TimeUnit -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import spark.ui.JettyUtils - -class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { - val SERVLET_KEY_URI = "uri" - val SERVLET_KEY_SAMPLE = "sample" - - val servletURI = property.getProperty(SERVLET_KEY_URI) - - val servletShowSample = property.getProperty(SERVLET_KEY_SAMPLE).toBoolean - - val mapper = new ObjectMapper().registerModule( - new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - - def getHandlers = Array[(String, Handler)]( - (servletURI, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) - ) - - def getMetricsSnapshot(request: HttpServletRequest): String = { - mapper.writeValueAsString(registry) - } - - override def start() { } - - override def stop() { } -} diff --git a/core/src/main/scala/spark/metrics/sink/Sink.scala b/core/src/main/scala/spark/metrics/sink/Sink.scala deleted file mode 100644 index dad1a7f0fe..0000000000 --- a/core/src/main/scala/spark/metrics/sink/Sink.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -trait Sink { - def start: Unit - def stop: Unit -} \ No newline at end of file diff --git a/core/src/main/scala/spark/metrics/source/JvmSource.scala b/core/src/main/scala/spark/metrics/source/JvmSource.scala deleted file mode 100644 index e771008557..0000000000 --- a/core/src/main/scala/spark/metrics/source/JvmSource.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.source - -import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} - -class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() - - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) -} diff --git a/core/src/main/scala/spark/metrics/source/Source.scala b/core/src/main/scala/spark/metrics/source/Source.scala deleted file mode 100644 index 76199a004b..0000000000 --- a/core/src/main/scala/spark/metrics/source/Source.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.source - -import com.codahale.metrics.MetricRegistry - -trait Source { - def sourceName: String - def metricRegistry: MetricRegistry -} diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala deleted file mode 100644 index e566aeac13..0000000000 --- a/core/src/main/scala/spark/network/BufferMessage.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import spark.storage.BlockManager - - -private[spark] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala deleted file mode 100644 index 1e571d39ae..0000000000 --- a/core/src/main/scala/spark/network/Connection.scala +++ /dev/null @@ -1,586 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import spark._ - -import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} - -import java.io._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ - - -private[spark] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) - extends Logging { - - def this(channel_ : SocketChannel, selector_ : Selector) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /*channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. - // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key() = channel.keyFor(selector) - - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. - def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.size + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[spark] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { - - private class Outbox(fair: Int = 0) { - val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized{ - /*messages += message*/ - messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - fair match { - case 0 => getChunkFIFO() - case 1 => getChunkRR() - case _ => throw new Exception("Unexpected fairness policy in outbox") - } - } - - private def getChunkFIFO(): Option[MessageChunk] = { - /*logInfo("Using FIFO")*/ - messages.synchronized { - while (!messages.isEmpty) { - val message = messages(0) - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - return chunk - } else { - /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - - private def getChunkRR(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /*val message = messages(nextMessageToBeUsed)*/ - val message = messages.dequeue - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.enqueue(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox(1) - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /*channel.socket.setSendBufferSize(256 * 1024)*/ - - override def getRemoteAddress() = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try{ - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => { - logError("Error connecting to " + address, e) - callOnExceptionCallback(e) - } - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. - val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - return true - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) - // ignore - return true - } - } - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as normal - // registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /*key.interestOps(0)*/ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - return true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The reciever's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... - return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - - if (currentChunk.buffer.remaining == 0) { - /*println("Filled buffer at " + System.currentTimeMillis)*/ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - return true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala deleted file mode 100644 index 8b9f3ae18c..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ /dev/null @@ -1,720 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import spark._ - -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} - -import scala.collection.mutable.HashSet -import scala.collection.mutable.HashMap -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.ArrayBuffer - -import akka.dispatch.{Await, Promise, ExecutionContext, Future} -import akka.util.Duration -import akka.util.duration._ - - -private[spark] class ConnectionManager(port: Int) extends Logging { - - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { - - var ackMessage: Option[Message] = None - var attempted = false - var acked = false - - def markDone() { completionHandler(this) } - } - - private val selector = SelectorProvider.provider.openSelector() - - private val handleMessageExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.handler.threads.min","20").toInt, - System.getProperty("spark.core.connection.handler.threads.max","60").toInt, - System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.io.threads.min","4").toInt, - System.getProperty("spark.core.connection.io.threads.max","32").toInt, - System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.connect.threads.min","1").toInt, - System.getProperty("spark.core.connection.connect.threads.max","8").toInt, - System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - private val serverChannel = ServerSocketChannel.open() - private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - serverChannel.socket.bind(new InetSocketAddress(port)) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - selectorThread.start() - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } - } ) - } - - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. - if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need not - // succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallback(e) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while(!selectorThread.isInterrupted) { - while (! registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. - if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently. - case e: CancelledKeyException => { - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext()) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - 0 - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (! sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert (sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registerations happen ! - wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } - } - handleMessageExecutor.execute(runnable) - /*handleMessage(connection, message)*/ - } - - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (bufferMessage.hasAckId) { - val sentMessageStatus = messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status - } - case None => { - throw new Exception("Could not find reference for received ack message " + message.id) - null - } - } - } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } - } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - message.senderAddress = id.toSocketAddress() - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - - wakeupSelector() - } - - private def wakeupSelector() { - selector.wakeup() - } - - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage)) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - sendMessage(connectionManagerId, message) - promise.future - } - - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - selectorThread.interrupt() - selectorThread.join() - selector.close() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - /*testSequentialSending(manager)*/ - /*System.gc()*/ - - /*testParallelSending(manager)*/ - /*System.gc()*/ - - /*testParallelDecreasingSending(manager)*/ - /*System.gc()*/ - - testContinuousSending(manager) - System.gc() - } - - def testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /*println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } -} diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala deleted file mode 100644 index 9d5c518293..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManagerId.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.net.InetSocketAddress - -import spark.Utils - - -private[spark] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - def toSocketAddress() = new InetSocketAddress(host, port) -} - - -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 9e3827aaf5..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import spark._ -import spark.SparkContext._ - -import scala.io.Source - -import java.nio.ByteBuffer -import java.net.InetAddress - -import akka.dispatch.Await -import akka.util.duration._ - -private[spark] object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - // - the master URL - // - a list slaves to run connectionTest on - //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts - //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 - //[count] - how many times to run, default is 3 - //[await time in seconds] : await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /*println("Slaves")*/ - /*slaves.foreach(println)*/ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second - println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala deleted file mode 100644 index a25457ea35..0000000000 --- a/core/src/main/scala/spark/network/Message.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetSocketAddress - -import scala.collection.mutable.ArrayBuffer - - -private[spark] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" -} - - -private[spark] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId() = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - return createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala deleted file mode 100644 index 784db5ab62..0000000000 --- a/core/src/main/scala/spark/network/MessageChunk.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - - -private[network] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size = if (buffer == null) 0 else buffer.remaining - - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } - - override def toString = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } -} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala deleted file mode 100644 index 18d0cbcc14..0000000000 --- a/core/src/main/scala/spark/network/MessageChunkHeader.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.net.InetAddress -import java.net.InetSocketAddress -import java.nio.ByteBuffer - - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). - flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} - - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala deleted file mode 100644 index 2bbc736f40..0000000000 --- a/core/src/main/scala/spark/network/ReceiverTest.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -private[spark] object ReceiverTest { - - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ - val buffer = ByteBuffer.wrap("response".getBytes()) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} - diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala deleted file mode 100644 index 542c54c36b..0000000000 --- a/core/src/main/scala/spark/network/SenderTest.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -private[spark] object SenderTest { - - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest ") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /*println("Started timer at " + startTime)*/ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { - case Some(response) => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array) - case None => "none" - } - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/ - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala deleted file mode 100644 index bf46d32aa3..0000000000 --- a/core/src/main/scala/spark/network/netty/FileHeader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty - -import io.netty.buffer._ - -import spark.Logging - -private[spark] class FileHeader ( - val fileLen: Int, - val blockId: String) extends Logging { - - lazy val buffer = { - val buf = Unpooled.buffer() - buf.capacity(FileHeader.HEADER_SIZE) - buf.writeInt(fileLen) - buf.writeInt(blockId.length) - blockId.foreach((x: Char) => buf.writeByte(x)) - //padding the rest of header - if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { - buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) - } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") - } - buf - } - -} - -private[spark] object FileHeader { - - val HEADER_SIZE = 40 - - def getFileLenOffset = 0 - def getFileLenSize = Integer.SIZE/8 - - def create(buf: ByteBuf): FileHeader = { - val length = buf.readInt - val idLength = buf.readInt - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buf.readByte().asInstanceOf[Char] - } - val blockId = idBuilder.toString() - new FileHeader(length, blockId) - } - - - def main (args:Array[String]){ - - val header = new FileHeader(25,"block_0"); - val buf = header.buffer; - val newheader = FileHeader.create(buf); - System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) - - } -} - diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala deleted file mode 100644 index b01f6369f6..0000000000 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty - -import java.util.concurrent.Executors - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.util.CharsetUtil - -import spark.Logging -import spark.network.ConnectionManagerId - -import scala.collection.JavaConverters._ - - -private[spark] class ShuffleCopier extends Logging { - - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt - val fc = new FileClient(handler, connectTimeout) - - try { - fc.init() - fc.connect(host, port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() - } catch { - // Handle any socket-related exceptions in FileClient - case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) - handler.handleError(blockId) - } - } - } - - def getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - - for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) - } - } -} - - -private[spark] object ShuffleCopier extends Logging { - - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) - extends FileClientHandler with Logging { - - override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } - - override def handleError(blockId: String) { - if (!isComplete) { - resultCollectCallBack(blockId, -1, null) - } - } - } - - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - if (size != -1) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: ShuffleCopier ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val file = args(2) - val threads = if (args.length > 3) args(3).toInt else 10 - - val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { - Executors.callable(new Runnable() { - def run() { - val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) - } - }) - }).asJava - copiers.invokeAll(tasks) - copiers.shutdown - System.exit(0) - } -} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala deleted file mode 100644 index cdf88b03a0..0000000000 --- a/core/src/main/scala/spark/network/netty/ShuffleSender.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty - -import java.io.File - -import spark.Logging - - -private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - - val server = new FileServer(pResolver, portIn) - server.start() - - def stop() { - server.stop() - } - - def port: Int = server.getPort() -} - - -/** - * An application for testing the shuffle sender as a standalone program. - */ -private[spark] object ShuffleSender { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ShuffleSender ") - System.exit(1) - } - - val port = args(0).toInt - val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2).map(new File(_)) - - val pResovler = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - throw new Exception("Block " + blockId + " is not a shuffle block") - } - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = math.abs(blockId.hashCode) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId) - return file.getAbsolutePath - } - } - val sender = new ShuffleSender(port, pResovler) - } -} diff --git a/core/src/main/scala/spark/package.scala b/core/src/main/scala/spark/package.scala deleted file mode 100644 index b244bfbf06..0000000000 --- a/core/src/main/scala/spark/package.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Core Spark functionality. [[spark.SparkContext]] serves as the main entry point to Spark, while - * [[spark.RDD]] is the data type representing a distributed collection, and provides most - * parallel operations. - * - * In addition, [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value - * pairs, such as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations - * available only on RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations - * available on RDDs that can be saved as SequenceFiles. These operations are automatically - * available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit conversions when - * you `import spark.SparkContext._`. - */ -package object spark { - // For package docs only -} diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala deleted file mode 100644 index 691d939150..0000000000 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import spark._ -import spark.scheduler.JobListener - -/** - * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). - * This listener waits up to timeout milliseconds and will return a partial answer even if the - * complete answer is not available by then. - * - * This class assumes that the action is performed on an entire RDD[T] via a function that computes - * a result of type U for each partition, and that the action returns a partial or complete result - * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt). - */ -private[spark] class ApproximateActionListener[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long) - extends JobListener { - - val startTime = System.currentTimeMillis() - val totalTasks = rdd.partitions.size - var finishedTasks = 0 - var failure: Option[Exception] = None // Set if the job has failed (permanently) - var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult - - override def taskSucceeded(index: Int, result: Any) { - synchronized { - evaluator.merge(index, result.asInstanceOf[U]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - // If we had already returned a PartialResult, set its final value - resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called awaitResult - this.notifyAll() - } - } - } - - override def jobFailed(exception: Exception) { - synchronized { - failure = Some(exception) - this.notifyAll() - } - } - - /** - * Waits for up to timeout milliseconds since the listener was created and then returns a - * PartialResult with the result so far. This may be complete if the whole job is done. - */ - def awaitResult(): PartialResult[R] = synchronized { - val finishTime = startTime + timeout - while (true) { - val time = System.currentTimeMillis() - if (failure != None) { - throw failure.get - } else if (finishedTasks == totalTasks) { - return new PartialResult(evaluator.currentResult(), true) - } else if (time >= finishTime) { - resultObject = Some(new PartialResult(evaluator.currentResult(), false)) - return resultObject.get - } else { - this.wait(finishTime - time) - } - } - // Should never be reached, but required to keep the compiler happy - return null - } -} diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala deleted file mode 100644 index 5eae144dfb..0000000000 --- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -/** - * An object that computes a function incrementally by merging in results of type U from multiple - * tasks. Allows partial evaluation at any point by calling currentResult(). - */ -private[spark] trait ApproximateEvaluator[U, R] { - def merge(outputId: Int, taskResult: U): Unit - def currentResult(): R -} diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala deleted file mode 100644 index 8bdbe6c012..0000000000 --- a/core/src/main/scala/spark/partial/BoundedDouble.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -/** - * A Double with error bars on it. - */ -class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { - override def toString(): String = "[%.3f, %.3f]".format(low, high) -} diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala deleted file mode 100644 index 6aa92094eb..0000000000 --- a/core/src/main/scala/spark/partial/CountEvaluator.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -/** - * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. - */ -private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[Long, BoundedDouble] { - - var outputsMerged = 0 - var sum: Long = 0 - - override def merge(outputId: Int, taskResult: Long) { - outputsMerged += 1 - sum += taskResult - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala deleted file mode 100644 index ebe2e5a1e3..0000000000 --- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions.mapAsScalaMap - -import cern.jet.stat.Probability - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -/** - * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new OLMap[T] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: OLMap[T]) { - outputsMerged += 1 - val iter = taskResult.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue() - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index 2dadbbd5fb..0000000000 --- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index ae2b63f7cb..0000000000 --- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala deleted file mode 100644 index 5ddcad7075..0000000000 --- a/core/src/main/scala/spark/partial/MeanEvaluator.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means. - */ -private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala deleted file mode 100644 index 922a9f9bc6..0000000000 --- a/core/src/main/scala/spark/partial/PartialResult.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -class PartialResult[R](initialVal: R, isFinal: Boolean) { - private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None - private var failure: Option[Exception] = None - private var completionHandler: Option[R => Unit] = None - private var failureHandler: Option[Exception => Unit] = None - - def initialValue: R = initialVal - - def isInitialValueFinal: Boolean = isFinal - - /** - * Blocking method to wait for and return the final value. - */ - def getFinalValue(): R = synchronized { - while (finalValue == None && failure == None) { - this.wait() - } - if (finalValue != None) { - return finalValue.get - } else { - throw failure.get - } - } - - /** - * Set a handler to be called when this PartialResult completes. Only one completion handler - * is supported per PartialResult. - */ - def onComplete(handler: R => Unit): PartialResult[R] = synchronized { - if (completionHandler != None) { - throw new UnsupportedOperationException("onComplete cannot be called twice") - } - completionHandler = Some(handler) - if (finalValue != None) { - // We already have a final value, so let's call the handler - handler(finalValue.get) - } - return this - } - - /** - * Set a handler to be called if this PartialResult's job fails. Only one failure handler - * is supported per PartialResult. - */ - def onFail(handler: Exception => Unit) { - synchronized { - if (failureHandler != None) { - throw new UnsupportedOperationException("onFail cannot be called twice") - } - failureHandler = Some(handler) - if (failure != None) { - // We already have a failure, so let's call the handler - handler(failure.get) - } - } - } - - /** - * Transform this PartialResult into a PartialResult of type T. - */ - def map[T](f: R => T) : PartialResult[T] = { - new PartialResult[T](f(initialVal), isFinal) { - override def getFinalValue() : T = synchronized { - f(PartialResult.this.getFinalValue()) - } - override def onComplete(handler: T => Unit): PartialResult[T] = synchronized { - PartialResult.this.onComplete(handler.compose(f)).map(f) - } - override def onFail(handler: Exception => Unit) { - synchronized { - PartialResult.this.onFail(handler) - } - } - override def toString : String = synchronized { - PartialResult.this.getFinalValueInternal() match { - case Some(value) => "(final: " + f(value) + ")" - case None => "(partial: " + initialValue + ")" - } - } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) - } - } - - private[spark] def setFinalValue(value: R) { - synchronized { - if (finalValue != None) { - throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult") - } - finalValue = Some(value) - // Call the completion handler if it was set - completionHandler.foreach(h => h(value)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - } - - private def getFinalValueInternal() = finalValue - - private[spark] def setFailure(exception: Exception) { - synchronized { - if (failure != None) { - throw new UnsupportedOperationException("setFailure called twice on a PartialResult") - } - failure = Some(exception) - // Call the failure handler if it was set - failureHandler.foreach(h => h(exception)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - } - - override def toString: String = synchronized { - finalValue match { - case Some(value) => "(final: " + value + ")" - case None => "(partial: " + initialValue + ")" - } - } -} diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala deleted file mode 100644 index f3bb987d46..0000000000 --- a/core/src/main/scala/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. - */ -private[spark] class StudentTCacher(confidence: Double) { - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - cache(size) = Probability.studentTInverse(1 - confidence, size - 1) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala deleted file mode 100644 index 4083abef03..0000000000 --- a/core/src/main/scala/spark/partial/SumEvaluator.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums. It estimates the mean and the cont and multiplies them - * together, then uses the formula for the variance of two independent random variables to get - * a variance for the result and compute a confidence interval. - */ -private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala deleted file mode 100644 index 03800584ae..0000000000 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} -import spark.storage.BlockManager - -private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { - val index = idx -} - -private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc, Nil) { - - @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) - - override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { - new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val blockManager = SparkEnv.get.blockManager - val blockId = split.asInstanceOf[BlockRDDPartition].blockId - blockManager.get(blockId) match { - case Some(block) => block.asInstanceOf[Iterator[T]] - case None => - throw new Exception("Could not compute split, block " + blockId + " not found") - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - locations_(split.asInstanceOf[BlockRDDPartition].blockId) - } -} - diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala deleted file mode 100644 index 91b3e69d6f..0000000000 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.{ObjectOutputStream, IOException} -import spark._ - - -private[spark] -class CartesianPartition( - idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], - s1Index: Int, - s2Index: Int - ) extends Partition { - var s1 = rdd1.partitions(s1Index) - var s2 = rdd2.partitions(s2Index) - override val index: Int = idx - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - s1 = rdd1.partitions(s1Index) - s2 = rdd2.partitions(s2Index) - oos.defaultWriteObject() - } -} - -private[spark] -class CartesianRDD[T: ClassManifest, U:ClassManifest]( - sc: SparkContext, - var rdd1 : RDD[T], - var rdd2 : RDD[U]) - extends RDD[Pair[T, U]](sc, Nil) - with Serializable { - - val numPartitionsInRdd2 = rdd2.partitions.size - - override def getPartitions: Array[Partition] = { - // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) - for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { - val idx = s1.index * numPartitionsInRdd2 + s2.index - array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) - } - array - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val currSplit = split.asInstanceOf[CartesianPartition] - (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct - } - - override def compute(split: Partition, context: TaskContext) = { - val currSplit = split.asInstanceOf[CartesianPartition] - for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) - } - - override def getDependencies: Seq[Dependency[_]] = List( - new NarrowDependency(rdd1) { - def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2) - }, - new NarrowDependency(rdd2) { - def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2) - } - ) - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala deleted file mode 100644 index 1ad5fe6539..0000000000 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark._ -import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.{NullWritable, BytesWritable} -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.fs.Path -import java.io.{File, IOException, EOFException} -import java.text.NumberFormat - -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} - -/** - * This RDD represents a RDD checkpoint file (similar to HadoopRDD). - */ -private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) - extends RDD[T](sc, Nil) { - - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - - override def getPartitions: Array[Partition] = { - val cpath = new Path(checkpointPath) - val numPartitions = - // listStatus can throw exception if path does not exist. - if (fs.exists(cpath)) { - val dirContents = fs.listStatus(cpath) - val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numPart = partitionFiles.size - if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } - numPart - } else 0 - - Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) - } - - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - - override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) - val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") - } - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, context) - } - - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } -} - -private[spark] object CheckpointRDD extends Logging { - - def splitIdToFile(splitId: Int): String = { - "part-%05d".format(splitId) - } - - def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { - val env = SparkEnv.get - val outputDir = new Path(path) - val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) - - val finalOutputName = splitIdToFile(ctx.splitId) - val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) - - if (fs.exists(tempOutputPath)) { - throw new IOException("Checkpoint failed: temporary path " + - tempOutputPath + " already exists") - } - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) - } else { - // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) - } - val serializer = env.serializer.newInstance() - val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() - - if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.exists(finalOutputPath)) { - logInfo("Deleting tempOutputPath " + tempOutputPath) - fs.delete(tempOutputPath, false) - throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptId + " and final output path does not exist") - } else { - // Some other copy of this task must've finished before us and renamed it - logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") - fs.delete(tempOutputPath, false) - } - } - } - - def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { - val env = SparkEnv.get - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val fileInputStream = fs.open(path, bufferSize) - val serializer = env.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) - - deserializeStream.asIterator.asInstanceOf[Iterator[T]] - } - - // Test whether CheckpointRDD generate expected number of partitions despite - // each split file having multiple blocks. This needs to be run on a - // cluster (mesos or standalone) using HDFS. - def main(args: Array[String]) { - import spark._ - - val Array(cluster, hdfsPath) = args - val env = SparkEnv.get - val sc = new SparkContext(cluster, "CheckpointRDD Test") - val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) - val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) - val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") - fs.delete(path, true) - } -} diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala deleted file mode 100644 index 01b6c23dcc..0000000000 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.{ObjectOutputStream, IOException} -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConversions -import scala.collection.mutable.ArrayBuffer - -import spark.{Partition, Partitioner, RDD, SparkEnv, TaskContext} -import spark.{Dependency, OneToOneDependency, ShuffleDependency} - - -private[spark] sealed trait CoGroupSplitDep extends Serializable - -private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, - var split: Partition - ) extends CoGroupSplitDep { - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) - oos.defaultWriteObject() - } -} - -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep - -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) - extends Partition with Serializable { - override val index: Int = idx - override def hashCode(): Int = idx -} - - -/** - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a - * tuple with the list of values for that key. - * - * @param rdds parent RDDs. - * @param part partitioner used to partition the shuffle output. - */ -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): CoGroupedRDD[K] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => - if (rdd.partitioner == Some(part)) { - logDebug("Adding one-to-one dependency with " + rdd) - new OneToOneDependency(rdd) - } else { - logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializerClass) - } - } - } - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { - // Each CoGroupPartition will have a dependency per contributing RDD - array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => - // Assume each RDD contributed a single dependency, and get it - dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) - case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) - } - }.toArray) - } - array - } - - override val partitioner = Some(part) - - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { - val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size - // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] - - def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) - map.put(k, seq) - seq - } - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass) - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { - // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } - } - case ShuffleCoGroupSplitDep(shuffleId) => { - // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 - } - } - } - JavaConversions.mapAsScalaMap(map).iterator - } - - override def clearDependencies() { - super.clearDependencies() - rdds = null - } -} diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala deleted file mode 100644 index e612d026b2..0000000000 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark._ -import java.io.{ObjectOutputStream, IOException} -import scala.collection.mutable -import scala.Some -import scala.collection.mutable.ArrayBuffer - -/** - * Class that captures a coalesced RDD by essentially keeping track of parent partitions - * @param index of this coalesced partition - * @param rdd which it belongs to - * @param parentsIndices list of indices in the parent that have been coalesced into this partition - * @param preferredLocation the preferred location for this partition - */ -case class CoalescedRDDPartition( - index: Int, - @transient rdd: RDD[_], - parentsIndices: Array[Int], - @transient preferredLocation: String = "" - ) extends Partition { - var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent partition at the time of task serialization - parents = parentsIndices.map(rdd.partitions(_)) - oos.defaultWriteObject() - } - - /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations - * @return locality of this coalesced partition between 0 and 1 - */ - def localFraction: Double = { - val loc = parents.count(p => - rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation)) - - if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble) - } -} - -/** - * Represents a coalesced RDD that has fewer partitions than its parent RDD - * This class uses the PartitionCoalescer class to find a good partitioning of the parent RDD - * so that each new partition has roughly the same number of parent partitions and that - * the preferred location of each new partition overlaps with as many preferred locations of its - * parent partitions - * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD - * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance - */ -class CoalescedRDD[T: ClassManifest]( - @transient var prev: RDD[T], - maxPartitions: Int, - balanceSlack: Double = 0.10) - extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - - override def getPartitions: Array[Partition] = { - val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) - - pc.run().zipWithIndex.map { - case (pg, i) => - val ids = pg.arr.map(_.index).toArray - new CoalescedRDDPartition(i, prev, ids, pg.prefLoc) - } - } - - override def compute(partition: Partition, context: TaskContext): Iterator[T] = { - partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition => - firstParent[T].iterator(parentPartition, context) - } - } - - override def getDependencies: Seq[Dependency[_]] = { - Seq(new NarrowDependency(prev) { - def getParents(id: Int): Seq[Int] = - partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices - }) - } - - override def clearDependencies() { - super.clearDependencies() - prev = null - } - - /** - * Returns the preferred machine for the partition. If split is of type CoalescedRDDPartition, - * then the preferred machine will be one which most parent splits prefer too. - * @param partition - * @return the machine most preferred by split - */ - override def getPreferredLocations(partition: Partition): Seq[String] = { - List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation) - } -} - -/** - * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of - * this RDD computes one or more of the parent ones. It will produce exactly `maxPartitions` if the - * parent had more than maxPartitions, or fewer if the parent had fewer. - * - * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, - * or to avoid having a large number of small tasks when processing a directory with many files. - * - * If there is no locality information (no preferredLocations) in the parent, then the coalescing - * is very simple: chunk parents that are close in the Array in chunks. - * If there is locality information, it proceeds to pack them with the following four goals: - * - * (1) Balance the groups so they roughly have the same number of parent partitions - * (2) Achieve locality per partition, i.e. find one machine which most parent partitions prefer - * (3) Be efficient, i.e. O(n) algorithm for n parent partitions (problem is likely NP-hard) - * (4) Balance preferred machines, i.e. avoid as much as possible picking the same preferred machine - * - * Furthermore, it is assumed that the parent RDD may have many partitions, e.g. 100 000. - * We assume the final number of desired partitions is small, e.g. less than 1000. - * - * The algorithm tries to assign unique preferred machines to each partition. If the number of - * desired partitions is greater than the number of preferred machines (can happen), it needs to - * start picking duplicate preferred machines. This is determined using coupon collector estimation - * (2n log(n)). The load balancing is done using power-of-two randomized bins-balls with one twist: - * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two - * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions - * according to locality. (contact alig for questions) - * - */ - -private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { - - def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size - def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean = - if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get) - - val rnd = new scala.util.Random(7919) // keep this class deterministic - - // each element of groupArr represents one coalesced partition - val groupArr = ArrayBuffer[PartitionGroup]() - - // hash used to check whether some machine is already in groupArr - val groupHash = mutable.Map[String, ArrayBuffer[PartitionGroup]]() - - // hash used for the first maxPartitions (to avoid duplicates) - val initialHash = mutable.Set[Partition]() - - // determines the tradeoff between load-balancing the partitions sizes and their locality - // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt - - var noLocality = true // if true if no preferredLocations exists for parent RDD - - // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones) - def currPrefLocs(part: Partition): Seq[String] = { - prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host) - } - - // this class just keeps iterating and rotating infinitely over the partitions of the RDD - // next() returns the next preferred machine that a partition is replicated on - // the rotator first goes through the first replica copy of each partition, then second, third - // the iterators return type is a tuple: (replicaString, partition) - class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] { - - var it: Iterator[(String, Partition)] = resetIterator() - - override val isEmpty = !it.hasNext - - // initializes/resets to start iterating from the beginning - def resetIterator() = { - val iterators = (0 to 2).map( x => - prev.partitions.iterator.flatMap(p => { - if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None - } ) - ) - iterators.reduceLeft((x, y) => x ++ y) - } - - // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } - - // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { - if (it.hasNext) - it.next() - else { - it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning - it.next() - } - } - } - - /** - * Sorts and gets the least element of the list associated with key in groupHash - * The returned PartitionGroup is the least loaded of all groups that represent the machine "key" - * @param key string representing a partitioned group on preferred machine key - * @return Option of PartitionGroup that has least elements for key - */ - def getLeastGroupHash(key: String): Option[PartitionGroup] = { - groupHash.get(key).map(_.sortWith(compare).head) - } - - def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = { - if (!initialHash.contains(part)) { - pgroup.arr += part // already assign this element - initialHash += part // needed to avoid assigning partitions to multiple buckets - true - } else { false } - } - - /** - * Initializes targetLen partition groups and assigns a preferredLocation - * This uses coupon collector to estimate how many preferredLocations it must rotate through - * until it has seen most of the preferred locations (2 * n log(n)) - * @param targetLen - */ - def setupGroups(targetLen: Int) { - val rotIt = new LocationIterator(prev) - - // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { - (1 to targetLen).foreach(x => groupArr += PartitionGroup()) - return - } - - noLocality = false - - // number of iterations needed to be certain that we've seen most preferred locations - val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt - var numCreated = 0 - var tries = 0 - - // rotate through until either targetLen unique/distinct preferred locations have been created - // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations, - // i.e. likely targetLen >> number of preferred locations (more buckets than there are machines) - while (numCreated < targetLen && tries < expectedCoupons2) { - tries += 1 - val (nxt_replica, nxt_part) = rotIt.next() - if (!groupHash.contains(nxt_replica)) { - val pgroup = PartitionGroup(nxt_replica) - groupArr += pgroup - addPartToPGroup(nxt_part, pgroup) - groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple - numCreated += 1 - } - } - - while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates - var (nxt_replica, nxt_part) = rotIt.next() - val pgroup = PartitionGroup(nxt_replica) - groupArr += pgroup - groupHash.get(nxt_replica).get += pgroup - var tries = 0 - while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part - nxt_part = rotIt.next()._2 - tries += 1 - } - numCreated += 1 - } - - } - - /** - * Takes a parent RDD partition and decides which of the partition groups to put it in - * Takes locality into account, but also uses power of 2 choices to load balance - * It strikes a balance between the two use the balanceSlack variable - * @param p partition (ball to be thrown) - * @return partition group (bin to be put in) - */ - def pickBin(p: Partition): PartitionGroup = { - val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs - val prefPart = if (pref == Nil) None else pref.head - - val r1 = rnd.nextInt(groupArr.size) - val r2 = rnd.nextInt(groupArr.size) - val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2) - if (prefPart== None) // if no preferred locations, just use basic power of two - return minPowerOfTwo - - val prefPartActual = prefPart.get - - if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows - return minPowerOfTwo // prefer balance over locality - else { - return prefPartActual // prefer locality over balance - } - } - - def throwBalls() { - if (noLocality) { // no preferredLocations in parent RDD, no randomization needed - if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { - groupArr(i).arr += p - } - } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { - val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt - val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt - (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } - } - } - } else { - for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group - pickBin(p).arr += p - } - } - } - - def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray - - /** - * Runs the packing algorithm and returns an array of PartitionGroups that if possible are - * load balanced and grouped by locality - * @return array of partition groups - */ - def run(): Array[PartitionGroup] = { - setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) - throwBalls() // assign partitions (balls) to each group (bins) - getPartitions - } -} - -private[spark] case class PartitionGroup(prefLoc: String = "") { - var arr = mutable.ArrayBuffer[Partition]() - - def size = arr.size -} diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala deleted file mode 100644 index d7d4db5d30..0000000000 --- a/core/src/main/scala/spark/rdd/EmptyRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} - - -/** - * An RDD that is empty, i.e. has no element in it. - */ -class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { - - override def getPartitions: Array[Partition] = Array.empty - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - throw new UnsupportedOperationException("empty RDD") - } -} diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala deleted file mode 100644 index 783508cfd1..0000000000 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{OneToOneDependency, RDD, Partition, TaskContext} - -private[spark] class FilteredRDD[T: ClassManifest]( - prev: RDD[T], - f: T => Boolean) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = prev.partitioner // Since filter cannot change a partition's keys - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).filter(f) -} diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala deleted file mode 100644 index ed75eac3ff..0000000000 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -private[spark] -class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => TraversableOnce[U]) - extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).flatMap(f) -} diff --git a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala deleted file mode 100644 index a6bdce89d8..0000000000 --- a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{TaskContext, Partition, RDD} - - -private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev) { - - override def getPartitions = firstParent[Product2[K, V]].partitions - - override val partitioner = firstParent[Product2[K, V]].partitioner - - override def compute(split: Partition, context: TaskContext) = { - firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) => - f(v).map(x => (k, x)) - } - } -} diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala deleted file mode 100644 index 1573f8a289..0000000000 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - -private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) - extends RDD[Array[T]](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - Array(firstParent[T].iterator(split, context).toArray).iterator -} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala deleted file mode 100644 index e512423fd6..0000000000 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.EOFException -import java.util.NoSuchElementException - -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.FileInputFormat -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.InputSplit -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.util.ReflectionUtils - -import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, SparkEnv, TaskContext} -import spark.util.NextIterator -import org.apache.hadoop.conf.{Configuration, Configurable} - - -/** - * A Spark split class that wraps around a Hadoop InputSplit. - */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) - extends Partition { - - val inputSplit = new SerializableWritable[InputSplit](s) - - override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt - - override val index: Int = idx -} - -/** - * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file - * system, or S3, tables in HBase, etc). - */ -class HadoopRDD[K, V]( - sc: SparkContext, - @transient conf: JobConf, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int) - extends RDD[(K, V)](sc, Nil) with Logging { - - // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - - override def getPartitions: Array[Partition] = { - val env = SparkEnv.get - env.hadoop.addCredentials(conf) - val inputFormat = createInputFormat(conf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(conf) - } - val inputSplits = inputFormat.getSplits(conf, minSplits) - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) - } - array - } - - def createInputFormat(conf: JobConf): InputFormat[K, V] = { - ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) - .asInstanceOf[InputFormat[K, V]] - } - - override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null - - val conf = confBroadcast.value.value - val fmt = createInputFormat(conf) - if (fmt.isInstanceOf[Configurable]) { - fmt.asInstanceOf[Configurable].setConf(conf) - } - reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => closeIfNeeded() } - - val key: K = reader.createKey() - val value: V = reader.createValue() - - override def getNext() = { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true - } - (key, value) - } - - override def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - // TODO: Filtering out "localhost" in case of file:// URLs - val hadoopSplit = split.asInstanceOf[HadoopPartition] - hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") - } - - override def checkpoint() { - // Do nothing. Hadoop RDD should not be checkpointed. - } - - def getConf: Configuration = confBroadcast.value.value -} diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala deleted file mode 100644 index 59132437d2..0000000000 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.sql.{Connection, ResultSet} - -import spark.{Logging, Partition, RDD, SparkContext, TaskContext} -import spark.util.NextIterator - -private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx -} - -/** - * An RDD that executes an SQL query on a JDBC connection and reads results. - * For usage example, see test case JdbcRDDSuite. - * - * @param getConnection a function that returns an open Connection. - * The RDD takes care of closing the connection. - * @param sql the text of the query. - * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" - * @param lowerBound the minimum value of the first placeholder - * @param upperBound the maximum value of the second placeholder - * The lower and upper bounds are inclusive. - * @param numPartitions the number of partitions. - * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, - * the query would be executed twice, once with (1, 10) and once with (11, 20) - * @param mapRow a function from a ResultSet to a single row of the desired result type(s). - * This should only call getInt, getString, etc; the RDD takes care of calling next. - * The default maps a ResultSet to an array of Object. - */ -class JdbcRDD[T: ClassManifest]( - sc: SparkContext, - getConnection: () => Connection, - sql: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int, - mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) - extends RDD[T](sc, Nil) with Logging { - - override def getPartitions: Array[Partition] = { - // bounds are inclusive, hence the + 1 here and - 1 on end - val length = 1 + upperBound - lowerBound - (0 until numPartitions).map(i => { - val start = lowerBound + ((i * length) / numPartitions).toLong - val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 - new JdbcPartition(i, start, end) - }).toArray - } - - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } - val part = thePart.asInstanceOf[JdbcPartition] - val conn = getConnection() - val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - - // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, - // rather than pulling entire resultset into memory. - // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html - if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { - stmt.setFetchSize(Integer.MIN_VALUE) - logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") - } - - stmt.setLong(1, part.lower) - stmt.setLong(2, part.upper) - val rs = stmt.executeQuery() - - override def getNext: T = { - if (rs.next()) { - mapRow(rs) - } else { - finished = true - null.asInstanceOf[T] - } - } - - override def close() { - try { - if (null != rs && ! rs.isClosed()) rs.close() - } catch { - case e: Exception => logWarning("Exception closing resultset", e) - } - try { - if (null != stmt && ! stmt.isClosed()) stmt.close() - } catch { - case e: Exception => logWarning("Exception closing statement", e) - } - try { - if (null != conn && ! stmt.isClosed()) conn.close() - logInfo("closed connection") - } catch { - case e: Exception => logWarning("Exception closing connection", e) - } - } - } -} - -object JdbcRDD { - def resultSetToObjectArray(rs: ResultSet) = { - Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) - } -} diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala deleted file mode 100644 index af8f0a112f..0000000000 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala deleted file mode 100644 index 3b4e9518fd..0000000000 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition specific - * information such as the number of tuples in a partition. - */ -private[spark] -class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala deleted file mode 100644 index 8b411dd85d..0000000000 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - -private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) - extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).map(f) -} diff --git a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/spark/rdd/MappedValuesRDD.scala deleted file mode 100644 index 8334e3b557..0000000000 --- a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - - -import spark.{TaskContext, Partition, RDD} - -private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) - extends RDD[(K, U)](prev) { - - override def getPartitions = firstParent[Product2[K, U]].partitions - - override val partitioner = firstParent[Product2[K, U]].partitioner - - override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = { - firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) } - } -} diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala deleted file mode 100644 index b1877dc06e..0000000000 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ - -import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} - - -private[spark] -class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) - extends Partition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = (41 * (41 + rddId) + index) -} - -class NewHadoopRDD[K, V]( - sc : SparkContext, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - @transient conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { - - // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - // private val serializableConf = new SerializableWritable(conf) - - private val jobtrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient private val jobId = new JobID(jobtrackerId, id) - - override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(conf) - } - val jobContext = newJobContext(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) - - var havePair = false - var finished = false - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - havePair = !finished - } - !finished - } - - override def next: (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - return (reader.getCurrentKey, reader.getCurrentValue) - } - - private def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val theSplit = split.asInstanceOf[NewHadoopPartition] - theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") - } - - def getConf: Configuration = confBroadcast.value.value -} - diff --git a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala deleted file mode 100644 index 9154b76035..0000000000 --- a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RangePartitioner, Logging, RDD} - -/** - * Extra functions available on RDDs of (key, value) pairs where the key is sortable through - * an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these - * functions. They will work with any key type that has a `scala.math.Ordered` implementation. - */ -class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, - V: ClassManifest, - P <: Product2[K, V] : ClassManifest]( - self: RDD[P]) - extends Logging with Serializable { - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { - val part = new RangePartitioner(numPartitions, self, ascending) - val shuffled = new ShuffledRDD[K, V, P](self, part) - shuffled.mapPartitions(iter => { - val buf = iter.toArray - if (ascending) { - buf.sortWith((x, y) => x._1 < y._1).iterator - } else { - buf.sortWith((x, y) => x._1 > y._1).iterator - } - }, preservesPartitioning = true) - } -} diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala deleted file mode 100644 index 33079cd539..0000000000 --- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import scala.collection.immutable.NumericRange -import scala.collection.mutable.ArrayBuffer -import scala.collection.Map -import spark._ -import java.io._ -import scala.Serializable - -private[spark] class ParallelCollectionPartition[T: ClassManifest]( - var rddId: Long, - var slice: Int, - var values: Seq[T]) - extends Partition with Serializable { - - def iterator: Iterator[T] = values.iterator - - override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt - - override def equals(other: Any): Boolean = other match { - case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice) - case _ => false - } - - override def index: Int = slice - - @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = { - - val sfactory = SparkEnv.get.serializer - - // Treat java serializer with default action rather than going thru serialization, to avoid a - // separate serialization header. - - sfactory match { - case js: JavaSerializer => out.defaultWriteObject() - case _ => - out.writeLong(rddId) - out.writeInt(slice) - - val ser = sfactory.newInstance() - Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) - } - } - - @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = { - - val sfactory = SparkEnv.get.serializer - sfactory match { - case js: JavaSerializer => in.defaultReadObject() - case _ => - rddId = in.readLong() - slice = in.readInt() - - val ser = sfactory.newInstance() - Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) - } - } -} - -private[spark] class ParallelCollectionRDD[T: ClassManifest]( - @transient sc: SparkContext, - @transient data: Seq[T], - numSlices: Int, - locationPrefs: Map[Int, Seq[String]]) - extends RDD[T](sc, Nil) { - // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets - // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. - // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - - override def getPartitions: Array[Partition] = { - val slices = ParallelCollectionRDD.slice(data, numSlices).toArray - slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray - } - - override def compute(s: Partition, context: TaskContext) = - s.asInstanceOf[ParallelCollectionPartition[T]].iterator - - override def getPreferredLocations(s: Partition): Seq[String] = { - locationPrefs.getOrElse(s.index, Nil) - } -} - -private object ParallelCollectionRDD { - /** - * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range - * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes - * it efficient to run Spark over RDDs representing large sets of numbers. - */ - def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { - if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") - } - seq match { - case r: Range.Inclusive => { - val sign = if (r.step < 0) { - -1 - } else { - 1 - } - slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) - } - case r: Range => { - (0 until numSlices).map(i => { - val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i + 1) * r.length.toLong) / numSlices).toInt - new Range(r.start + start * r.step, r.start + end * r.step, r.step) - }).asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { - // For ranges of Long, Double, BigInteger, etc - val slices = new ArrayBuffer[Seq[T]](numSlices) - val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything - var r = nr - for (i <- 0 until numSlices) { - slices += r.take(sliceSize).asInstanceOf[Seq[T]] - r = r.drop(sliceSize) - } - slices - } - case _ => { - val array = seq.toArray // To prevent O(n^2) operations for List etc - (0 until numSlices).map(i => { - val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i + 1) * array.length.toLong) / numSlices).toInt - array.slice(start, end).toSeq - }) - } - } - } -} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala deleted file mode 100644 index d8700becb0..0000000000 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext} - - -class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { - override val index = idx -} - - -/** - * Represents a dependency between the PartitionPruningRDD and its parent. In this - * case, the child RDD contains a subset of partitions of the parents'. - */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) - extends NarrowDependency[T](rdd) { - - @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) - .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - - override def getParents(partitionId: Int) = List(partitions(partitionId).index) -} - - -/** - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on - * all partitions. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on partitions that don't have the range covering the key. - */ -class PartitionPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) - - override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions -} - - -object PartitionPruningRDD { - - /** - * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD - * when its type T is not known at compile time. - */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { - new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest) - } -} diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala deleted file mode 100644 index 2cefdc78b0..0000000000 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.PrintWriter -import java.util.StringTokenizer - -import scala.collection.Map -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.io.Source - -import spark.{RDD, SparkEnv, Partition, TaskContext} -import spark.broadcast.Broadcast - - -/** - * An RDD that pipes the contents of each parent partition through an external command - * (printing them one per line) and returns the output as a collection of strings. - */ -class PipedRDD[T: ClassManifest]( - prev: RDD[T], - command: Seq[String], - envVars: Map[String, String], - printPipeContext: (String => Unit) => Unit, - printRDDElement: (T, String => Unit) => Unit) - extends RDD[String](prev) { - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this( - prev: RDD[T], - command: String, - envVars: Map[String, String] = Map(), - printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null) = - this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) - - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) - // Add the environmental variables to the process. - val currentEnvVars = pb.environment() - envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } - - val proc = pb.start() - val env = SparkEnv.get - - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + command) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + command) { - override def run() { - SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - - // input the pipe context firstly - if (printPipeContext != null) { - printPipeContext(out.println(_)) - } - for (elem <- firstParent[T].iterator(split, context)) { - if (printRDDElement != null) { - printRDDElement(elem, out.println(_)) - } else { - out.println(elem) - } - } - out.close() - } - }.start() - - // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines - return new Iterator[String] { - def next() = lines.next() - def hasNext = { - if (lines.hasNext) { - true - } else { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - false - } - } - } - } -} - -object PipedRDD { - // Split a string into words using a standard StringTokenizer - def tokenize(command: String): Seq[String] = { - val buf = new ArrayBuffer[String] - val tok = new StringTokenizer(command) - while(tok.hasMoreElements) { - buf += tok.nextToken() - } - buf - } -} diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala deleted file mode 100644 index 574c9b141d..0000000000 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.util.Random - -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand - -import spark.{RDD, Partition, TaskContext} - -private[spark] -class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { - override val index: Int = prev.index -} - -class SampledRDD[T: ClassManifest]( - prev: RDD[T], - withReplacement: Boolean, - frac: Double, - seed: Int) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = { - val rg = new Random(seed) - firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - - override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { - val split = splitIn.asInstanceOf[SampledRDDPartition] - if (withReplacement) { - // For large datasets, the expected number of occurrences of each element in a sample with - // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new Poisson(frac, new DRand(split.seed)) - firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.nextInt() - if (count == 0) { - Iterator.empty // Avoid object allocation when we return 0 items, which is quite often - } else { - Iterator.fill(count)(element) - } - } - } else { // Sampling without replacement - val rand = new Random(split.seed) - firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) - } - } -} diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala deleted file mode 100644 index 51c05af064..0000000000 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{Dependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} - - -private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx - override def hashCode(): Int = idx -} - -/** - * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param prev the parent RDD. - * @param part the partitioner used to partition the RDD - * @tparam K the key class. - * @tparam V the value class. - */ -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( - @transient var prev: RDD[P], - part: Partitioner) - extends RDD[P](prev.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): ShuffledRDD[K, V, P] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializerClass)) - } - - override val partitioner = Some(part) - - override def getPartitions: Array[Partition] = { - Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) - } - - override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, - SparkEnv.get.serializerManager.get(serializerClass)) - } - - override def clearDependencies() { - super.clearDependencies() - prev = null - } -} diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala deleted file mode 100644 index dadef5e17d..0000000000 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import spark.RDD -import spark.Partitioner -import spark.Dependency -import spark.TaskContext -import spark.Partition -import spark.SparkEnv -import spark.ShuffleDependency -import spark.OneToOneDependency - - -/** - * An optimized version of cogroup for set difference/subtraction. - * - * It is possible to implement this operation with just `cogroup`, but - * that is less efficient because all of the entries from `rdd2`, for - * both matching and non-matching values in `rdd1`, are kept in the - * JHashMap until the end. - * - * With this implementation, only the entries from `rdd1` are kept in-memory, - * and the entries from `rdd2` are essentially streamed, as we only need to - * touch each once to decide if the value needs to be removed. - * - * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as - * you can use `rdd1`'s partitioner/partition size and not worry about running - * out of memory because of the size of `rdd2`. - */ -private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( - @transient var rdd1: RDD[_ <: Product2[K, V]], - @transient var rdd2: RDD[_ <: Product2[K, W]], - part: Partitioner) - extends RDD[(K, V)](rdd1.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): SubtractedRDD[K, V, W] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => - if (rdd.partitioner == Some(part)) { - logDebug("Adding one-to-one dependency with " + rdd) - new OneToOneDependency(rdd) - } else { - logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializerClass) - } - } - } - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { - // Each CoGroupPartition will depend on rdd1 and rdd2 - array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => - dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) - case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) - } - }.toArray) - } - array - } - - override val partitioner = Some(part) - - override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { - val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass) - val map = new JHashMap[K, ArrayBuffer[V]] - def getSeq(k: K): ArrayBuffer[V] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = new ArrayBuffer[V]() - map.put(k, seq) - seq - } - } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - } - case ShuffleCoGroupSplitDep(shuffleId) => { - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context.taskMetrics, serializer) - iter.foreach(op) - } - } - // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) - // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } - -} diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala deleted file mode 100644 index 2776826f18..0000000000 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import scala.collection.mutable.ArrayBuffer -import spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - -private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) - extends Partition { - - var split: Partition = rdd.partitions(splitIndex) - - def iterator(context: TaskContext) = rdd.iterator(split, context) - - def preferredLocations() = rdd.preferredLocations(split) - - override val index: Int = idx - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) - oos.defaultWriteObject() - } -} - -class UnionRDD[T: ClassManifest]( - sc: SparkContext, - @transient var rdds: Seq[RDD[T]]) - extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) - var pos = 0 - for (rdd <- rdds; split <- rdd.partitions) { - array(pos) = new UnionPartition(pos, rdd, split.index) - pos += 1 - } - array - } - - override def getDependencies: Seq[Dependency[_]] = { - val deps = new ArrayBuffer[Dependency[_]] - var pos = 0 - for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size - } - deps - } - - override def compute(s: Partition, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionPartition[T]].iterator(context) - - override def getPreferredLocations(s: Partition): Seq[String] = - s.asInstanceOf[UnionPartition[T]].preferredLocations() -} diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala deleted file mode 100644 index 9a0831bd89..0000000000 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - -private[spark] class ZippedPartitionsPartition( - idx: Int, - @transient rdds: Seq[RDD[_]]) - extends Partition { - - override val index: Int = idx - var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - partitionValues = rdds.map(rdd => rdd.partitions(idx)) - oos.defaultWriteObject() - } -} - -abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( - sc: SparkContext, - var rdds: Seq[RDD[_]]) - extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { - - override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") - } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) - } - array - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } - } - - override def clearDependencies() { - super.clearDependencies() - rdds = null - } -} - -class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} - -class ZippedPartitionsRDD3 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - } -} - -class ZippedPartitionsRDD4 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context), - rdd4.iterator(partitions(3), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - rdd4 = null - } -} diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala deleted file mode 100644 index 4074e50e44..0000000000 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - - -private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( - idx: Int, - @transient rdd1: RDD[T], - @transient rdd2: RDD[U] - ) extends Partition { - - var partition1 = rdd1.partitions(idx) - var partition2 = rdd2.partitions(idx) - override val index: Int = idx - - def partitions = (partition1, partition2) - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent partition at the time of task serialization - partition1 = rdd1.partitions(idx) - partition2 = rdd2.partitions(idx) - oos.defaultWriteObject() - } -} - -class ZippedRDD[T: ClassManifest, U: ClassManifest]( - sc: SparkContext, - var rdd1: RDD[T], - var rdd2: RDD[U]) - extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) { - - override def getPartitions: Array[Partition] = { - if (rdd1.partitions.size != rdd2.partitions.size) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") - } - val array = new Array[Partition](rdd1.partitions.size) - for (i <- 0 until rdd1.partitions.size) { - array(i) = new ZippedPartition(i, rdd1, rdd2) - } - array - } - - override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = { - val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context)) - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - val pref1 = rdd1.preferredLocations(partition1) - val pref2 = rdd2.preferredLocations(partition2) - // Check whether there are any hosts that match both RDDs; otherwise return the union - val exactMatchLocations = pref1.intersect(pref2) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - (pref1 ++ pref2).distinct - } - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala deleted file mode 100644 index fecc3e9648..0000000000 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.TaskContext - -import java.util.Properties - -/** - * Tracks information about an active job in the DAGScheduler. - */ -private[spark] class ActiveJob( - val jobId: Int, - val finalStage: Stage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], - val callSite: String, - val listener: JobListener, - val properties: Properties) { - - val numPartitions = partitions.length - val finished = Array.fill[Boolean](numPartitions)(false) - var numFinished = 0 -} diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala deleted file mode 100644 index 7275bd346a..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ /dev/null @@ -1,849 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io.NotSerializableException -import java.util.Properties -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} - -import spark._ -import spark.executor.TaskMetrics -import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import spark.scheduler.cluster.TaskInfo -import spark.storage.{BlockManager, BlockManagerMaster} -import spark.util.{MetadataCleaner, TimeStampedHashMap} - -/** - * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of - * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a - * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. - * - * In addition to coming up with a DAG of stages, this class also determines the preferred - * locations to run each task on, based on the current cache status, and passes these to the - * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being - * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task - * a small number of times before cancelling the whole stage. - * - * THREADING: This class runs all its logic in a single thread executing the run() method, to which - * events are submitted using a synchonized queue (eventQueue). The public API methods, such as - * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods - * should be private. - */ -private[spark] -class DAGScheduler( - taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv) - extends TaskSchedulerListener with Logging { - - def this(taskSched: TaskScheduler) { - this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) - } - taskSched.setListener(this) - - // Called by TaskScheduler to report task's starting. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(BeginEvent(task, taskInfo)) - } - - // Called by TaskScheduler to report task completions or failures. - override def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { - eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) - } - - // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { - eventQueue.put(ExecutorLost(execId)) - } - - // Called by TaskScheduler when a host is added - override def executorGained(execId: String, host: String) { - eventQueue.put(ExecutorGained(execId, host)) - } - - // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. - override def taskSetFailed(taskSet: TaskSet, reason: String) { - eventQueue.put(TaskSetFailed(taskSet, reason)) - } - - // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; - // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one - // as more failure events come in - val RESUBMIT_TIMEOUT = 50L - - // The time, in millis, to wake up between polls of the completion queue in order to potentially - // resubmit failed stages - val POLL_TIMEOUT = 10L - - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - - val nextJobId = new AtomicInteger(0) - - val nextStageId = new AtomicInteger(0) - - val stageIdToStage = new TimeStampedHashMap[Int, Stage] - - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] - - private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] - - private val listenerBus = new SparkListenerBus() - - // Contains the locations that each RDD's partitions are cached on - private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] - - // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with - // every task. When we detect a node failing, we note the current epoch number and failed - // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results. - // - // TODO: Garbage collect information about failure epochs when we know there are no more - // stray messages to detect. - val failedEpoch = new HashMap[String, Long] - - val idToActiveJob = new HashMap[Int, ActiveJob] - - val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done - val running = new HashSet[Stage] // Stages we are running right now - val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage - var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits - - val activeJobs = new HashSet[ActiveJob] - val resultStageToJob = new HashMap[Stage, ActiveJob] - - val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) - - // Start a thread to run the DAGScheduler event loop - def start() { - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() - } - - def addSparkListener(listener: SparkListener) { - listenerBus.addListener(listener) - } - - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { - if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) - cacheLocs(rdd.id) = blockIds.map { id => - locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) - } - } - cacheLocs(rdd.id) - } - - private def clearCacheLocs() { - cacheLocs.clear() - } - - /** - * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) - */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { - shuffleToMapStage.get(shuffleDep.shuffleId) match { - case Some(stage) => stage - case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId) - shuffleToMapStage(shuffleDep.shuffleId) = stage - stage - } - } - - /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. - */ - private def newStage( - rdd: RDD[_], - shuffleDep: Option[ShuffleDependency[_,_]], - jobId: Int, - callSite: Option[String] = None) - : Stage = - { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) - stageIdToStage(id) = stage - stageToInfos(stage) = StageInfo(stage) - stage - } - - /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. - */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { - val parents = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of partitions is unknown - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - parents += getShuffleMapStage(shufDep, jobId) - case _ => - visit(dep.rdd) - } - } - } - } - visit(rdd) - parents.toList - } - - private def getMissingParentStages(stage: Stage): List[Stage] = { - val missing = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(rdd: RDD[_]) { - if (!visited(rdd)) { - visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) - if (!mapStage.isAvailable) { - missing += mapStage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } - } - } - } - } - visit(stage.rdd) - missing.toList - } - - /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. - */ - private[scheduler] def prepareJob[T, U: ClassManifest]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: String, - allowLocal: Boolean, - resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = - { - assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) - } - - def runJob[T, U: ClassManifest]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: String, - allowLocal: Boolean, - resultHandler: (Int, U) => Unit, - properties: Properties = null) - { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". " + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) - waiter.awaitResult() match { - case JobSucceeded => {} - case JobFailed(exception: Exception, _) => - logInfo("Failed to run " + callSite) - throw exception - } - } - - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - callSite: String, - timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { - val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) - listener.awaitResult() // Will throw an exception if the job fails - } - - /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. - */ - private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { - event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + - " output partitions (allowLocal=" + allowLocal + ")") - logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { - // Compute very short actions like first() or take() with no parent stages locally. - runLocally(job) - } else { - listenerBus.post(SparkListenerJobStart(job, properties)) - idToActiveJob(jobId) = job - activeJobs += job - resultStageToJob(finalStage) = job - submitStage(finalStage) - } - - case ExecutorGained(execId, host) => - handleExecutorGained(execId, host) - - case ExecutorLost(execId) => - handleExecutorLost(execId) - - case begin: BeginEvent => - listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) - - case completion: CompletionEvent => - listenerBus.post(SparkListenerTaskEnd( - completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) - handleTaskCompletion(completion) - - case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) - - case StopDAGScheduler => - // Cancel any active jobs - for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") - job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, None))) - } - return true - } - false - } - - /** - * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since - * the last fetch failure. - */ - private[scheduler] def resubmitFailedStages() { - logInfo("Resubmitting failed stages") - clearCacheLocs() - val failed2 = failed.toArray - failed.clear() - for (stage <- failed2.sortBy(_.jobId)) { - submitStage(stage) - } - } - - /** - * Check for waiting or failed stages which are now eligible for resubmission. - * Ordinarily run on every iteration of the event loop. - */ - private[scheduler] def submitWaitingStages() { - // TODO: We might want to run this less often, when we are sure that something has become - // runnable that wasn't before. - logTrace("Checking for newly runnable parent stages") - logTrace("running: " + running) - logTrace("waiting: " + waiting) - logTrace("failed: " + failed) - val waiting2 = waiting.toArray - waiting.clear() - for (stage <- waiting2.sortBy(_.jobId)) { - submitStage(stage) - } - } - - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - private def run() { - SparkEnv.set(env) - - while (true) { - val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - } - this.synchronized { // needed in case other threads makes calls into methods of this class - if (event != null) { - if (processEvent(event)) { - return - } - } - - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability - // Periodically resubmit failed stages if some map output fetches have failed and we have - // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, - // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at - // the same time, so we want to make sure we've identified all the reduce tasks that depend - // on the failed node. - if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - resubmitFailedStages() - } else { - submitWaitingStages() - } - } - } - } - - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. - protected def runLocallyWithinThread(job: ActiveJob) { - try { - SparkEnv.set(env) - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.executeOnCompleteCallbacks() - } - } catch { - case e: Exception => - job.listener.jobFailed(e) - } - } - - /** Submits stage, but first recursively submits any missing parents. */ - private def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) - } - waiting += stage - } - } - } - - /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage) { - logDebug("submitMissingTasks(" + stage + ")") - // Get our pending tasks and remember them in our pendingTasks entry - val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) - myPending.clear() - var tasks = ArrayBuffer[Task[_]]() - if (stage.isShuffleMap) { - for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { - val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) - } - } else { - // This is a final stage; figure out its job's missing partitions - val job = resultStageToJob(stage) - for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) - } - } - // must be run listener before possible NotSerializableException - // should be "StageSubmitted" first and then "JobEnded" - val properties = idToActiveJob(stage.jobId).properties - listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties)) - - if (tasks.size > 0) { - // Preemptively serialize a task to make sure it can be serialized. We are catching this - // exception here because it would be fairly hard to catch the non-serializable exception - // down the road, where we have several different implementations for local scheduler and - // cluster schedulers. - try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) - } catch { - case e: NotSerializableException => - abortStage(stage, e.toString) - running -= stage - return - } - - logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - myPending ++= tasks - logDebug("New pending tasks: " + myPending) - taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - if (!stage.submissionTime.isDefined) { - stage.submissionTime = Some(System.currentTimeMillis()) - } - } else { - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - running -= stage - } - } - - /** - * Responds to a task finishing. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. - */ - private def handleTaskCompletion(event: CompletionEvent) { - val task = event.task - val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage) = { - val serviceTime = stage.submissionTime match { - case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) - case _ => "Unkown" - } - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.completionTime = Some(System.currentTimeMillis) - listenerBus.post(StageCompleted(stageToInfos(stage))) - running -= stage - } - event.reason match { - case Success => - logInfo("Completed " + task) - if (event.accumUpdates != null) { - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted - } - pendingTasks(stage) -= task - stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics - task match { - case rt: ResultTask[_, _] => - resultStageToJob.get(stage) match { - case Some(job) => - if (!job.finished(rt.outputId)) { - job.finished(rt.outputId) = true - job.numFinished += 1 - // If the whole job has finished, remove it - if (job.numFinished == job.numPartitions) { - idToActiveJob -= stage.jobId - activeJobs -= job - resultStageToJob -= stage - markStageAsFinished(stage) - listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) - } - job.listener.taskSucceeded(rt.outputId, event.result) - } - case None => - logInfo("Ignoring result from " + rt + " because its job has finished") - } - - case smt: ShuffleMapTask => - val status = event.result.asInstanceOf[MapStatus] - val execId = status.location.executorId - logDebug("ShuffleMapTask finished on " + execId) - if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) - } else { - stage.addOutputLoc(smt.partition, status) - } - if (running.contains(stage) && pendingTasks(stage).isEmpty) { - markStageAsFinished(stage) - logInfo("looking for newly runnable stages") - logInfo("running: " + running) - logInfo("waiting: " + waiting) - logInfo("failed: " + failed) - if (stage.shuffleDep != None) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } - clearCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { - // Some tasks had failed; let's resubmit this stage - // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + - ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) - } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waiting) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) - } - for (stage <- waiting if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage - } - waiting --= newlyRunnable - running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage) - } - } - } - } - - case Resubmitted => - logInfo("Resubmitted " + task + ", so marking it as still running") - pendingTasks(stage) += task - - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - // Mark the stage that the reducer was in as unrunnable - val failedStage = stageIdToStage(task.stageId) - running -= failedStage - failed += failedStage - // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " (" + failedStage.name + - ") for resubmision due to a fetch failure") - // Mark the map whose fetch failed as broken in the map stage - val mapStage = shuffleToMapStage(shuffleId) - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } - logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + - "); marking it for resubmission") - failed += mapStage - // Remember that a fetch failed now; this is used to resubmit the broken - // stages later, after a small wait (to give other tasks the chance to fail) - lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) - } - - case ExceptionFailure(className, description, stackTrace, metrics) => - // Do nothing here, left up to the TaskScheduler to decide how to handle user failures - - case other => - // Unrecognized failure - abort all jobs depending on this stage - abortStage(stageIdToStage(task.stageId), task + " failed: " + other) - } - } - - /** - * Responds to an executor being lost. This is called inside the event loop, so it assumes it can - * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. - * - * Optionally the epoch during which the failure was caught can be passed to avoid allowing - * stray fetch failures from possibly retriggering the detection of a node as lost. - */ - private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { - val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) - if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { - failedEpoch(execId) = currentEpoch - logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) - blockManagerMaster.removeExecutor(execId) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) - } - if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() - } - clearCacheLocs() - } else { - logDebug("Additional executor lost message for " + execId + - "(epoch " + currentEpoch + ")") - } - } - - private def handleExecutorGained(execId: String, host: String) { - // remove from failedEpoch(execId) ? - if (failedEpoch.contains(execId)) { - logInfo("Host gained which was in lost list earlier: " + host) - failedEpoch -= execId - } - } - - /** - * Aborts all jobs depending on a particular Stage. This is called in response to a task set - * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. - */ - private def abortStage(failedStage: Stage, reason: String) { - val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq - failedStage.completionTime = Some(System.currentTimeMillis()) - for (resultStage <- dependentStages) { - val job = resultStageToJob(resultStage) - val error = new SparkException("Job failed: " + reason) - job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) - idToActiveJob -= resultStage.jobId - activeJobs -= job - resultStageToJob -= resultStage - } - if (dependentStages.isEmpty) { - logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") - } - } - - /** - * Return true if one of stage's ancestors is target. - */ - private def stageDependsOn(stage: Stage, target: Stage): Boolean = { - if (stage == target) { - return true - } - val visitedRdds = new HashSet[RDD[_]] - val visitedStages = new HashSet[Stage] - def visit(rdd: RDD[_]) { - if (!visitedRdds(rdd)) { - visitedRdds += rdd - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) - if (!mapStage.isAvailable) { - visitedStages += mapStage - visit(mapStage.rdd) - } // Otherwise there's no need to follow the dependency back - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } - } - } - } - visit(stage.rdd) - visitedRdds.contains(target.rdd) - } - - /** - * Synchronized method that might be called from other threads. - * @param rdd whose partitions are to be looked at - * @param partition to lookup locality information for - * @return list of machines that are preferred by the partition - */ - private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { - // If the partition is cached, return the cache locations - val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { - return cached - } - // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { - return rddPrefs.map(host => TaskLocation(host)) - } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. - rdd.dependencies.foreach(_ match { - case n: NarrowDependency[_] => - for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) - if (locs != Nil) - return locs - } - case _ => - }) - Nil - } - - private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) - } - - def stop() { - eventQueue.put(StopDAGScheduler) - metadataCleaner.cancel() - taskSched.stop() - } -} diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala deleted file mode 100644 index b8ba0e9239..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties - -import spark.scheduler.cluster.TaskInfo -import scala.collection.mutable.Map - -import spark._ -import spark.executor.TaskMetrics - -/** - * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue - * architecture where any thread can post an event (e.g. a task finishing or a new job being - * submitted) but there is a single "logic" thread that reads these events and takes decisions. - * This greatly simplifies synchronization. - */ -private[spark] sealed trait DAGSchedulerEvent - -private[spark] case class JobSubmitted( - finalRDD: RDD[_], - func: (TaskContext, Iterator[_]) => _, - partitions: Array[Int], - allowLocal: Boolean, - callSite: String, - listener: JobListener, - properties: Properties = null) - extends DAGSchedulerEvent - -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent - -private[spark] case class CompletionEvent( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - extends DAGSchedulerEvent - -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent - -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent - -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent - -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala deleted file mode 100644 index 98c4fb7e59..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.scheduler - -import com.codahale.metrics.{Gauge,MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "DAGScheduler" - - metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.failed.size - }) - - metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.running.size - }) - - metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.waiting.size - }) - - metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() - }) - - metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.activeJobs.size - }) -} diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala deleted file mode 100644 index 8f1b9b29b5..0000000000 --- a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.{Logging, SparkEnv} -import scala.collection.immutable.Set -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.conf.Configuration -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ - - -/** - * Parses and holds information about inputFormat (and files) specified as a parameter. - */ -class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], - val path: String) extends Logging { - - var mapreduceInputFormat: Boolean = false - var mapredInputFormat: Boolean = false - - validate() - - override def toString(): String = { - "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path - } - - override def hashCode(): Int = { - var hashCode = inputFormatClazz.hashCode - hashCode = hashCode * 31 + path.hashCode - hashCode - } - - // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path - // .. which is fine, this is best case effort to remove duplicates - right ? - override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { - // not checking config - that should be fine, right ? - this.inputFormatClazz == that.inputFormatClazz && - this.path == that.path - } - case _ => false - } - - private def validate() { - logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) - - try { - if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { - logDebug("inputformat is from mapreduce package") - mapreduceInputFormat = true - } - else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { - logDebug("inputformat is from mapred package") - mapredInputFormat = true - } - else { - throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + - " is NOT a supported input format ? does not implement either of the supported hadoop api's") - } - } - catch { - case e: ClassNotFoundException => { - throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } - } - } - - - // This method does not expect failures, since validate has already passed ... - private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get - val conf = new JobConf(configuration) - env.hadoop.addCredentials(conf) - FileInputFormat.setInputPaths(conf, path) - - val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = - ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ - org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) - - val retval = new ArrayBuffer[SplitInfo]() - val list = instance.getSplits(job) - for (split <- list) { - retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) - } - - return retval.toSet - } - - // This method does not expect failures, since validate has already passed ... - private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get - val jobConf = new JobConf(configuration) - env.hadoop.addCredentials(jobConf) - FileInputFormat.setInputPaths(jobConf, path) - - val instance: org.apache.hadoop.mapred.InputFormat[_, _] = - ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ - org.apache.hadoop.mapred.InputFormat[_, _]] - - val retval = new ArrayBuffer[SplitInfo]() - instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( - elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) - ) - - return retval.toSet - } - - private def findPreferredLocations(): Set[SplitInfo] = { - logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + - ", inputFormatClazz : " + inputFormatClazz) - if (mapreduceInputFormat) { - return prefLocsFromMapreduceInputFormat() - } - else { - assert(mapredInputFormat) - return prefLocsFromMapredInputFormat() - } - } -} - - - - -object InputFormatInfo { - /** - Computes the preferred locations based on input(s) and returned a location to block map. - Typical use of this method for allocation would follow some algo like this - (which is what we currently do in YARN branch) : - a) For each host, count number of splits hosted on that host. - b) Decrement the currently allocated containers on that host. - c) Compute rack info for each host and update rack -> count map based on (b). - d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node - (even if data locality on that is very high) : this is to prevent fragility of job if a single - (or small set of) hosts go down. - - go to (a) until required nodes are allocated. - - If a node 'dies', follow same procedure. - - PS: I know the wording here is weird, hopefully it makes some sense ! - */ - def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { - - val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] - for (inputSplit <- formats) { - val splits = inputSplit.findPreferredLocations() - - for (split <- splits){ - val location = split.hostLocation - val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) - set += split - } - } - - nodeToSplit - } -} diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala deleted file mode 100644 index af108b8fec..0000000000 --- a/core/src/main/scala/spark/scheduler/JobListener.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -/** - * Interface used to listen for job completion or failure events after submitting a job to the - * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole - * job fails (and no further taskSucceeded events will happen). - */ -private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) -} diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala deleted file mode 100644 index 1bc9fabdff..0000000000 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io.PrintWriter -import java.io.File -import java.io.FileNotFoundException -import java.text.SimpleDateFormat -import java.util.{Date, Properties} -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.{Map, HashMap, ListBuffer} -import scala.io.Source - -import spark._ -import spark.executor.TaskMetrics -import spark.scheduler.cluster.TaskInfo - -// Used to record runtime information for each job, including RDD graph -// tasks' start/stop shuffle information and information from outside - -class JobLogger(val logDirName: String) extends SparkListener with Logging { - private val logDir = - if (System.getenv("SPARK_LOG_DIR") != null) - System.getenv("SPARK_LOG_DIR") - else - "/tmp/spark" - private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIDToJobID = new HashMap[Int, Int] - private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] - private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] - - createLogDir() - def this() = this(String.valueOf(System.currentTimeMillis())) - - def getLogDir = logDir - def getJobIDtoPrintWriter = jobIDToPrintWriter - def getStageIDToJobID = stageIDToJobID - def getJobIDToStages = jobIDToStages - def getEventQueue = eventQueue - - // Create a folder for log files, the folder's name is the creation time of the jobLogger - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (dir.exists()) { - return - } - if (dir.mkdirs() == false) { - logError("create log directory error:" + logDir + "/" + logDirName + "/") - } - } - - // Create a log file for one job, the file name is the jobID - protected def createLogWriter(jobID: Int) { - try{ - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) - jobIDToPrintWriter += (jobID -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - // Close log file, and clean the stage relationship in stageIDToJobID - protected def closeLogWriter(jobID: Int) = - jobIDToPrintWriter.get(jobID).foreach { fileWriter => - fileWriter.close() - jobIDToStages.get(jobID).foreach(_.foreach{ stage => - stageIDToJobID -= stage.id - }) - jobIDToPrintWriter -= jobID - jobIDToStages -= jobID - } - - // Write log information to log file, withTime parameter controls whether to recored - // time stamp for the information - protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { - var writeInfo = info - if (withTime) { - val date = new Date(System.currentTimeMillis()) - writeInfo = DATE_FORMAT.format(date) + ": " +info - } - jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) - } - - protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) = - stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) - - protected def buildJobDep(jobID: Int, stage: Stage) { - if (stage.jobId == jobID) { - jobIDToStages.get(jobID) match { - case Some(stageList) => stageList += stage - case None => val stageList = new ListBuffer[Stage] - stageList += stage - jobIDToStages += (jobID -> stageList) - } - stageIDToJobID += (stage.id -> jobID) - stage.parents.foreach(buildJobDep(jobID, _)) - } - } - - protected def recordStageDep(jobID: Int) { - def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { - var rddList = new ListBuffer[RDD[_]] - rddList += rdd - rdd.dependencies.foreach{ dep => dep match { - case shufDep: ShuffleDependency[_,_] => - case _ => rddList ++= getRddsInStage(dep.rdd) - } - } - rddList - } - jobIDToStages.get(jobID).foreach {_.foreach { stage => - var depRddDesc: String = "" - getRddsInStage(stage.rdd).foreach { rdd => - depRddDesc += rdd.id + "," - } - var depStageDesc: String = "" - stage.parents.foreach { stage => - depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" - } - jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + - depRddDesc.substring(0, depRddDesc.length - 1) + ")" + - " STAGE_DEP=" + depStageDesc, false) - } - } - } - - // Generate indents and convert to String - protected def indentString(indent: Int) = { - val sb = new StringBuilder() - for (i <- 1 to indent) { - sb.append(" ") - } - sb.toString() - } - - protected def getRddName(rdd: RDD[_]) = { - var rddName = rdd.getClass.getName - if (rdd.name != null) { - rddName = rdd.name - } - rddName - } - - protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { - val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")" - jobLogInfo(jobID, indentString(indent) + rddInfo, false) - rdd.dependencies.foreach{ dep => dep match { - case shufDep: ShuffleDependency[_,_] => - val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId - jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) - case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1) - } - } - } - - protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) { - var stageInfo: String = "" - if (stage.isShuffleMap) { - stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + - stage.shuffleDep.get.shuffleId - }else{ - stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE" - } - if (stage.jobId == jobID) { - jobLogInfo(jobID, indentString(indent) + stageInfo, false) - recordRddInStageGraph(jobID, stage.rdd, indent) - stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2)) - } else - jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false) - } - - // Record task metrics into job log files - protected def recordTaskMetrics(stageID: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val readMetrics = - taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead - case None => "" - } - val writeMetrics = - taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => - " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten - case None => "" - } - stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - stageLogInfo( - stageSubmitted.stage.id, - "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageSubmitted.stage.id, stageSubmitted.taskSize)) - } - - override def onStageCompleted(stageCompleted: StageCompleted) { - stageLogInfo( - stageCompleted.stageInfo.stage.id, - "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id)) - - } - - override def onTaskStart(taskStart: SparkListenerTaskStart) { } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val task = taskEnd.task - val taskInfo = taskEnd.taskInfo - var taskStatus = "" - task match { - case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" - case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" - } - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId - stageLogInfo(task.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) - case _ => - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val job = jobEnd.job - var info = "JOB_ID=" + job.jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(job.jobId) - } - - protected def recordJobProperties(jobID: Int, properties: Properties) { - if(properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobID, description, false) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart) { - val job = jobStart.job - val properties = jobStart.properties - createLogWriter(job.jobId) - recordJobProperties(job.jobId, properties) - buildJobDep(job.jobId, job.finalStage) - recordStageDep(job.jobId) - recordStageDepGraph(job.jobId, job.finalStage) - jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED") - } -} diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala deleted file mode 100644 index a61b335152..0000000000 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -/** - * A result of a job in the DAGScheduler. - */ -private[spark] sealed trait JobResult - -private[spark] case object JobSucceeded extends JobResult -private[spark] case class JobFailed(exception: Exception, failedStage: Option[Stage]) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala deleted file mode 100644 index 69cd161c1f..0000000000 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import scala.collection.mutable.ArrayBuffer - -/** - * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their - * results to the given handler function. - */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) - extends JobListener { - - private var finishedTasks = 0 - - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? - private var jobResult: JobResult = null // If the job is finished, this will be its result - - override def taskSucceeded(index: Int, result: Any) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded - this.notifyAll() - } - } - } - - override def jobFailed(exception: Exception) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception, None) - this.notifyAll() - } - } - - def awaitResult(): JobResult = synchronized { - while (!jobFinished) { - this.wait() - } - return jobResult - } -} diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala deleted file mode 100644 index 2f6a68ee85..0000000000 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.storage.BlockManagerId -import java.io.{ObjectOutput, ObjectInput, Externalizable} - -/** - * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. - * The map output sizes are compressed using MapOutputTracker.compressSize. - */ -private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) - extends Externalizable { - - def this() = this(null, null) // For deserialization only - - def writeExternal(out: ObjectOutput) { - location.writeExternal(out) - out.writeInt(compressedSizes.length) - out.write(compressedSizes) - } - - def readExternal(in: ObjectInput) { - location = BlockManagerId(in) - compressedSizes = new Array[Byte](in.readInt()) - in.readFully(compressedSizes) - } -} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala deleted file mode 100644 index d066df5dc1..0000000000 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark._ -import java.io._ -import util.{MetadataCleaner, TimeStampedHashMap} -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - return bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - - -private[spark] class ResultTask[T, U]( - stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - var partition: Int, - @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId) with Externalizable { - - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } - - @transient private val preferredLocs: Seq[TaskLocation] = { - if (locs == null) Nil else locs.toSet.toSeq - } - - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId) - metrics = Some(context.taskMetrics) - try { - func(context, rdd.iterator(split, context)) - } finally { - context.executeOnCompleteCallbacks() - } - } - - override def preferredLocations: Seq[TaskLocation] = preferredLocs - - override def toString = "ResultTask(" + stageId + ", " + partition + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partition) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() - val outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } -} diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala deleted file mode 100644 index f2a038576b..0000000000 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap - -import spark._ -import spark.executor.ShuffleWriteMetrics -import spark.storage._ -import spark.util.{TimeStampedHashMap, MetadataCleaner} - - -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - return bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - synchronized { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) - } - } - - // Since both the JarSet and FileSet have the same format this is used for both. - def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - -private[spark] class ShuffleMapTask( - stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_,_], - var partition: Int, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) - with Externalizable - with Logging { - - protected def this() = this(0, null, null, 0, null) - - @transient private val preferredLocs: Seq[TaskLocation] = { - if (locs == null) Nil else locs.toSet.toSeq - } - - var split = if (rdd == null) null else rdd.partitions(partition) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partition) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partition = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - - override def run(attemptId: Long): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId) - metrics = Some(taskContext.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - var shuffle: ShuffleBlocks = null - var buckets: ShuffleWriterGroup = null - - try { - // Obtain all the block writers for shuffle blocks. - val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) - shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) - - // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - buckets.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.size() - totalBytes += size - MapOutputTracker.compressSize(size) - } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - return new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (buckets != null) { - buckets.writers.foreach(_.revertPartialWrites()) - } - throw e - } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && buckets != null) { - shuffle.releaseWriters(buckets) - } - // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() - } - } - - override def preferredLocations: Seq[TaskLocation] = preferredLocs - - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) -} diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala deleted file mode 100644 index e5531011c2..0000000000 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties -import spark.scheduler.cluster.TaskInfo -import spark.util.Distribution -import spark.{Logging, SparkContext, TaskEndReason, Utils} -import spark.executor.TaskMetrics - -sealed trait SparkListenerEvents - -case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties) - extends SparkListenerEvents - -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents - -case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents - -case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, - taskMetrics: TaskMetrics) extends SparkListenerEvents - -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) - extends SparkListenerEvents - -case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) - extends SparkListenerEvents - -trait SparkListener { - /** - * Called when a stage is completed, with information on the completed stage - */ - def onStageCompleted(stageCompleted: StageCompleted) { } - - /** - * Called when a stage is submitted - */ - def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } - - /** - * Called when a task starts - */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } - - /** - * Called when a task ends - */ - def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } - - /** - * Called when a job starts - */ - def onJobStart(jobStart: SparkListenerJobStart) { } - - /** - * Called when a job ends - */ - def onJobEnd(jobEnd: SparkListenerJobEnd) { } - -} - -/** - * Simple SparkListener that logs a few summary statistics when each stage completes - */ -class StatsReportListener extends SparkListener with Logging { - override def onStageCompleted(stageCompleted: StageCompleted) { - import spark.scheduler.StatsReportListener._ - implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) - showMillisDistribution("task runtime:", (info, _) => Some(info.duration)) - - //shuffle write - showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten}) - - //fetch & io - showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime}) - showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead}) - showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize)) - - //runtime breakdown - - val runtimePcts = stageCompleted.stageInfo.taskInfos.map{ - case (info, metrics) => RuntimePercentage(info.duration, metrics) - } - showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%") - showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%") - showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%") - } - -} - -object StatsReportListener extends Logging { - - //for profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) - val probabilities = percentiles.map{_ / 100.0} - val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - - def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(stage.stageInfo.taskInfos.flatMap{ - case ((info,metric)) => getMetric(info, metric)}) - } - - //is there some way to setup the types that I can get rid of this completely? - def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = { - extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble}) - } - - def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { - val stats = d.statCounter - logInfo(heading + stats) - val quantiles = d.getQuantiles(probabilities).map{formatNumber} - logInfo(percentilesHeader) - logInfo("\t" + quantiles.mkString("\t")) - } - - def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) { - dOpt.foreach { d => showDistribution(heading, d, formatNumber)} - } - - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d:Double) = format.format(d) - showDistribution(heading, dOpt, f _) - } - - def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double]) - (implicit stage: StageCompleted) { - showDistribution(heading, extractDoubleDistribution(stage, getMetric), format) - } - - def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { - showBytesDistribution(heading, extractLongDistribution(stage, getMetric)) - } - - def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { - dOpt.foreach{dist => showBytesDistribution(heading, dist)} - } - - def showBytesDistribution(heading: String, dist: Distribution) { - showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { - showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { - showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) - } - - - - val seconds = 1000L - val minutes = seconds * 60 - val hours = minutes * 60 - - /** - * reformat a time interval in milliseconds to a prettier format for output - */ - def millisToString(ms: Long) = { - val (size, units) = - if (ms > hours) { - (ms.toDouble / hours, "hours") - } else if (ms > minutes) { - (ms.toDouble / minutes, "min") - } else if (ms > seconds) { - (ms.toDouble / seconds, "s") - } else { - (ms.toDouble, "ms") - } - "%.1f %s".format(size, units) - } -} - - - -case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) -object RuntimePercentage { - def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { - val denom = totalTime.toDouble - val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime} - val fetch = fetchTime.map{_ / denom} - val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom - val other = 1.0 - (exec + fetch.getOrElse(0d)) - RuntimePercentage(exec, fetch, other) - } -} diff --git a/core/src/main/scala/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/spark/scheduler/SparkListenerBus.scala deleted file mode 100644 index f55ed455ed..0000000000 --- a/core/src/main/scala/spark/scheduler/SparkListenerBus.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import spark.Logging - -/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */ -private[spark] class SparkListenerBus() extends Logging { - private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener] - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - - new Thread("SparkListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - event match { - case stageSubmitted: SparkListenerStageSubmitted => - sparkListeners.foreach(_.onStageSubmitted(stageSubmitted)) - case stageCompleted: StageCompleted => - sparkListeners.foreach(_.onStageCompleted(stageCompleted)) - case jobStart: SparkListenerJobStart => - sparkListeners.foreach(_.onJobStart(jobStart)) - case jobEnd: SparkListenerJobEnd => - sparkListeners.foreach(_.onJobEnd(jobEnd)) - case taskStart: SparkListenerTaskStart => - sparkListeners.foreach(_.onTaskStart(taskStart)) - case taskEnd: SparkListenerTaskEnd => - sparkListeners.foreach(_.onTaskEnd(taskEnd)) - case _ => - } - } - } - }.start() - - def addListener(listener: SparkListener) { - sparkListeners += listener - } - - def post(event: SparkListenerEvents) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") - queueFullErrorMessageLogged = true - } - } -} - diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala deleted file mode 100644 index 4e3661ec5d..0000000000 --- a/core/src/main/scala/spark/scheduler/SplitInfo.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import collection.mutable.ArrayBuffer - -// information about a specific split instance : handles both split instances. -// So that we do not need to worry about the differences. -class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, - val length: Long, val underlyingSplit: Any) { - override def toString(): String = { - "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + - ", hostLocation : " + hostLocation + ", path : " + path + - ", length : " + length + ", underlyingSplit " + underlyingSplit - } - - override def hashCode(): Int = { - var hashCode = inputFormatClazz.hashCode - hashCode = hashCode * 31 + hostLocation.hashCode - hashCode = hashCode * 31 + path.hashCode - // ignore overflow ? It is hashcode anyway ! - hashCode = hashCode * 31 + (length & 0x7fffffff).toInt - hashCode - } - - // This is practically useless since most of the Split impl's dont seem to implement equals :-( - // So unless there is identity equality between underlyingSplits, it will always fail even if it - // is pointing to same block. - override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { - this.hostLocation == that.hostLocation && - this.inputFormatClazz == that.inputFormatClazz && - this.path == that.path && - this.length == that.length && - // other split specific checks (like start for FileSplit) - this.underlyingSplit == that.underlyingSplit - } - case _ => false - } -} - -object SplitInfo { - - def toSplitInfo(inputFormatClazz: Class[_], path: String, - mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { - val retval = new ArrayBuffer[SplitInfo]() - val length = mapredSplit.getLength - for (host <- mapredSplit.getLocations) { - retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) - } - retval - } - - def toSplitInfo(inputFormatClazz: Class[_], path: String, - mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { - val retval = new ArrayBuffer[SplitInfo]() - val length = mapreduceSplit.getLength - for (host <- mapreduceSplit.getLocations) { - retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) - } - retval - } -} diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala deleted file mode 100644 index c599c00ac4..0000000000 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.net.URI - -import spark._ -import spark.storage.BlockManagerId - -/** - * A stage is a set of independent tasks all computing the same function that need to run as part - * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run - * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the - * DAGScheduler runs these stages in topological order. - * - * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. - * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO - * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered - * faster on failure. - */ -private[spark] class Stage( - val id: Int, - val rdd: RDD[_], - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage - val parents: List[Stage], - val jobId: Int, - callSite: Option[String]) - extends Logging { - - val isShuffleMap = shuffleDep != None - val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 - - /** When first task was submitted to scheduler. */ - var submissionTime: Option[Long] = None - var completionTime: Option[Long] = None - - private var nextAttemptId = 0 - - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) - numAvailableOutputs += 1 - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - - def newAttemptId(): Int = { - val id = nextAttemptId - nextAttemptId += 1 - return id - } - - val name = callSite.getOrElse(rdd.origin) - - override def toString = "Stage " + id - - override def hashCode(): Int = id -} diff --git a/core/src/main/scala/spark/scheduler/StageInfo.scala b/core/src/main/scala/spark/scheduler/StageInfo.scala deleted file mode 100644 index c4026f995a..0000000000 --- a/core/src/main/scala/spark/scheduler/StageInfo.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.scheduler.cluster.TaskInfo -import scala.collection._ -import spark.executor.TaskMetrics - -case class StageInfo( - val stage: Stage, - val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() -) { - override def toString = stage.rdd.toString -} diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala deleted file mode 100644 index 0ab2ae6cfe..0000000000 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.serializer.SerializerInstance -import java.io.{DataInputStream, DataOutputStream} -import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.util.ByteBufferInputStream -import scala.collection.mutable.HashMap -import spark.executor.TaskMetrics - -/** - * A task to execute on a worker node. - */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T - def preferredLocations: Seq[TaskLocation] = Nil - - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. - - var metrics: Option[TaskMetrics] = None - -} - -/** - * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We - * need to send the list of JARs and files added to the SparkContext with each task to ensure that - * worker nodes find out about it, but we can't make it part of the Task because the user's code in - * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by - * first writing out its dependencies. - */ -private[spark] object Task { - /** - * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) - */ - def serializeWithDependencies( - task: Task[_], - currentFiles: HashMap[String, Long], - currentJars: HashMap[String, Long], - serializer: SerializerInstance) - : ByteBuffer = { - - val out = new FastByteArrayOutputStream(4096) - val dataOut = new DataOutputStream(out) - - // Write currentFiles - dataOut.writeInt(currentFiles.size) - for ((name, timestamp) <- currentFiles) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write currentJars - dataOut.writeInt(currentJars.size) - for ((name, timestamp) <- currentJars) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write the task itself and finish - dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) - out.trim() - ByteBuffer.wrap(out.array) - } - - /** - * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, - * and return the task itself as a serialized ByteBuffer. The caller can then update its - * ClassLoaders and deserialize the task. - * - * @return (taskFiles, taskJars, taskBytes) - */ - def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { - - val in = new ByteBufferInputStream(serializedTask) - val dataIn = new DataInputStream(in) - - // Read task's files - val taskFiles = new HashMap[String, Long]() - val numFiles = dataIn.readInt() - for (i <- 0 until numFiles) { - taskFiles(dataIn.readUTF()) = dataIn.readLong() - } - - // Read task's JARs - val taskJars = new HashMap[String, Long]() - val numJars = dataIn.readInt() - for (i <- 0 until numJars) { - taskJars(dataIn.readUTF()) = dataIn.readLong() - } - - // Create a sub-buffer for the rest of the data, which is the serialized Task object - val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) - } -} diff --git a/core/src/main/scala/spark/scheduler/TaskLocation.scala b/core/src/main/scala/spark/scheduler/TaskLocation.scala deleted file mode 100644 index fea117e956..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskLocation.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -/** - * A location where a task should run. This can either be a host or a (host, executorID) pair. - * In the latter case, we will prefer to launch the task on that executorID, but our next level - * of preference will be executors on the same host if this is not possible. - */ -private[spark] -class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { - override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" -} - -private[spark] object TaskLocation { - def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) - - def apply(host: String) = new TaskLocation(host, None) -} diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala deleted file mode 100644 index fc4856756b..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io._ - -import scala.collection.mutable.Map -import spark.executor.TaskMetrics -import spark.{Utils, SparkEnv} -import java.nio.ByteBuffer - -// Task result. Also contains updates to accumulator variables. -// TODO: Use of distributed cache to return result is a hack to get around -// what seems to be a bug with messages over 60KB in libprocess; fix it -private[spark] -class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) - extends Externalizable -{ - def this() = this(null.asInstanceOf[T], null, null) - - override def writeExternal(out: ObjectOutput) { - - val objectSer = SparkEnv.get.serializer.newInstance() - val bb = objectSer.serialize(value) - - out.writeInt(bb.remaining()) - Utils.writeByteBuffer(bb, out) - - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - out.writeObject(metrics) - } - - override def readExternal(in: ObjectInput) { - - val objectSer = SparkEnv.get.serializer.newInstance() - - val blen = in.readInt() - val byteVal = new Array[Byte](blen) - in.readFully(byteVal) - value = objectSer.deserialize(ByteBuffer.wrap(byteVal)) - - val numUpdates = in.readInt - if (numUpdates == 0) { - accumUpdates = null - } else { - accumUpdates = Map() - for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() - } - } - metrics = in.readObject().asInstanceOf[TaskMetrics] - } -} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala deleted file mode 100644 index 4943d58e25..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.scheduler.cluster.Pool -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -/** - * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. - */ -private[spark] trait TaskScheduler { - - def rootPool: Pool - - def schedulingMode: SchedulingMode - - def start(): Unit - - // Invoked after system has successfully initialized (typically in spark context). - // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. - def postStartHook() { } - - // Disconnect from the cluster. - def stop(): Unit - - // Submit a sequence of tasks to run. - def submitTasks(taskSet: TaskSet): Unit - - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit - - // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. - def defaultParallelism(): Int -} diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 64be50b2d0..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.scheduler.cluster.TaskInfo -import scala.collection.mutable.Map - -import spark.TaskEndReason -import spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala deleted file mode 100644 index dc3550dd0b..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties - -/** - * A set of tasks submitted together to the low-level TaskScheduler, usually representing - * missing partitions of a particular stage. - */ -private[spark] class TaskSet( - val tasks: Array[Task[_]], - val stageId: Int, - val attempt: Int, - val priority: Int, - val properties: Properties) { - val id: String = stageId + "." + attempt - - override def toString: String = "TaskSet " + id -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala deleted file mode 100644 index 679d899b47..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.lang.{Boolean => JBoolean} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import spark._ -import spark.TaskState.TaskState -import spark.scheduler._ -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.{TimerTask, Timer} - -/** - * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call - * initialize() and start(), then submit task sets through the runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. - * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple - * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then - * acquire a lock on us, so we need to make sure that we don't try to lock the backend while - * we are holding a lock on ourselves. - */ -private[spark] class ClusterScheduler(val sc: SparkContext) - extends TaskScheduler - with Logging -{ - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - - // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - - val activeTaskSets = new HashMap[String, TaskSetManager] - - val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - @volatile private var hasReceivedTask = false - @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer(true) - - // Incrementing Mesos task IDs - val nextTaskId = new AtomicLong(0) - - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] - - // The set of executors we have on each host; this is used to compute hostsAlive, which - // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] - - private val executorIdToHost = new HashMap[String, String] - - // JAR server, if any JARs were added by the user to the SparkContext - var jarServer: HttpServer = null - - // URIs of JARs to pass to executor - var jarUris: String = "" - - // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null - - var backend: SchedulerBackend = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - // default scheduler is FIFO - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.cluster.schedulingmode", "FIFO")) - - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener - } - - def initialize(context: SchedulerBackend) { - backend = context - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - } - - def newTaskId(): Long = nextTaskId.getAndIncrement() - - override def start() { - backend.start() - - if (System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("ClusterScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() - } - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") - this.synchronized { - val manager = new ClusterTaskSetManager(this, taskSet) - activeTaskSets(taskSet.id) = manager - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - - if (!hasReceivedTask) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true - } - backend.reviveOffers() - } - - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so - * that tasks are balanced across the cluster. - */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorGained(o.executorId, o.host) - } - } - - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - } - - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. - var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host - for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - } - } - } while (launchedTask) - } - - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var taskSetToUpdate: Option[TaskSetManager] = None - var failedExecutor: Option[String] = None - var taskFailed = false - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (activeTaskSets.contains(taskSetId)) { - taskSetToUpdate = Some(activeTaskSets(taskSetId)) - } - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - if (state == TaskState.FAILED) { - taskFailed = true - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock - if (taskSetToUpdate != None) { - taskSetToUpdate.get.statusUpdate(tid, state, serializedData) - } - if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) - backend.reviveOffers() - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } - } - - def error(message: String) { - synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { - try { - manager.error(message) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } else { - // No task sets are active but we still got an error. Just exit since this - // must mean the error is during registration. - // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) - } - } - } - - override def stop() { - if (backend != null) { - backend.stop() - } - if (jarServer != null) { - jarServer.stop() - } - - // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. - // TODO: Do something better ! - Thread.sleep(5000L) - } - - override def defaultParallelism() = backend.defaultParallelism() - - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - shouldRevive = rootPool.checkSpeculatableTasks() - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - - def executorLost(executorId: String, reason: ExecutorLossReason) { - var failedExecutor: Option[String] = None - - synchronized { - if (activeExecutorIds.contains(executorId)) { - val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) - failedExecutor = Some(executorId) - } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } - } - // Call listener.executorLost without holding the lock on this to prevent deadlock - if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) - } - - def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) - } - - def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) - } - - def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) - } - - def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) - } - - // By default, rack is unknown - def getRackForHost(value: String): Option[String] = None -} - - -object ClusterScheduler { - /** - * Used to balance containers across hosts. - * - * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource - * offers are ordered such that we'll allocate one container on each host before allocating a - * second container on any host, and so on, in order to reduce the damage if a host fails. - * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] - */ - def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { - val _keyList = new ArrayBuffer[K](map.size) - _keyList ++= map.keys - - // order keyList based on population of value in map - val keyList = _keyList.sortWith( - (left, right) => map(left).size > map(right).size - ) - - val retval = new ArrayBuffer[T](keyList.size * 2) - var index = 0 - var found = true - - while (found) { - found = false - for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) - assert(containerList != null) - // Get the index'th entry for this host - if present - if (index < containerList.size){ - retval += containerList.apply(index) - found = true - } - } - index += 1 - } - - retval.toList - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala deleted file mode 100644 index a4d6880abb..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ /dev/null @@ -1,712 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer -import java.util.{Arrays, NoSuchElementException} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils} -import spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure} -import spark.TaskState.TaskState -import spark.scheduler._ -import scala.Some -import spark.FetchFailed -import spark.ExceptionFailure -import spark.TaskResultTooBigFailure -import spark.util.{SystemClock, Clock} - - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of - * the status of each task, retries tasks if they fail (up to a limited number of times), and - * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces - * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, - * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the - * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. - */ -private[spark] class ClusterTaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet, - clock: Clock = SystemClock) - extends TaskSetManager - with Logging -{ - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - var weight = 1 - var minShare = 0 - var runningTasks = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent: Schedulable = null - - // Set of pending tasks for each executor. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each host. Similar to pendingTasksForExecutor, - // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each rack -- similar to the above. - private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] - - // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - - // Map of recent exceptions (identified by string representation and top stack frame) to - // duplicate count (how many times the same exception has appeared) and time the full exception - // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch - logDebug("Epoch for " + taskSet + ": " + epoch) - for (t <- tasks) { - t.epoch = epoch - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level - - // Delay scheduling variables: we keep track of our current locality level and the time we - // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. - // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level - - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { - list += index - } - } - - var hadAliveLocations = false - for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } - } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } - hadAliveLocations = true - } - } - - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. - addTo(pendingTasksWithNoPrefs) - } - - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } - } - - /** - * Return the pending tasks list for a given executor ID, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { - pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) - } - - /** - * Return the pending tasks list for a given host, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - /** - * Return the pending rack-local task list for a given rack, or an empty list if - * there is no map entry for that rack - */ - private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { - pendingTasksForRack.getOrElse(rack, ArrayBuffer()) - } - - /** - * Dequeue a pending task from the given list and return its index. - * Return None if the list is empty. - * This method also cleans up any tasks in the list that have already - * been launched, since we want that to happen lazily. - */ - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - /** Check whether a task is currently running an attempt on a given host */ - private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { - !taskAttempts(taskIndex).exists(_.host == host) - } - - /** - * Return a speculative task for a given executor if any are available. The task should not have - * an attempt running on this host, in case the host is slow. In addition, the task should meet - * the given locality constraint. - */ - private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - - if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local - // on multiple nodes when we replicate cached blocks, as in Spark Streaming - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { - speculatableTasks -= index - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - } - - // Check for node-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val locations = tasks(index).preferredLocations.map(_.host) - if (locations.contains(host)) { - speculatableTasks -= index - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - } - - // Check for rack-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) - if (racks.contains(rack)) { - speculatableTasks -= index - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - } - } - - // Check for non-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - speculatableTasks -= index - return Some((index, TaskLocality.ANY)) - } - } - } - - return None - } - - /** - * Dequeue a pending task for a given node and return its index and locality level. - * Only search for tasks matching the given locality constraint. - */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- findTaskFromList(getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for { - rack <- sched.getRackForHost(host) - index <- findTaskFromList(getPendingTasksForRack(rack)) - } { - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - - // Look for no-pref tasks after rack-local tasks since they can run anywhere. - for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- findTaskFromList(allPendingTasks)) { - return Some((index, TaskLocality.ANY)) - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) - } - - /** - * Respond to an offer of a single slave from the scheduler by finding a task - */ - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = clock.getTime() - - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks - } - - findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - // Serialize and return the task - val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = clock.getTime() - startTime - increaseRunningTasks(1) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - if (taskAttempts(index).size == 1) - taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) - } - case _ => - } - } - return None - } - - /** - * Get the level we can launch tasks according to delay scheduling, based on current wait time. - */ - private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - } - myLocalityLevels(currentLocalityIndex) - } - - /** - * Find the index in myLocalityLevels for a given locality. This is also designed to work with - * localities that are not in myLocalityLevels (in case we somehow get those) by returning the - * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. - */ - def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { - var index = 0 - while (locality > myLocalityLevels(index)) { - index += 1 - } - index - } - - /** Called by cluster scheduler when one of our tasks changes state */ - override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markSuccessful() - decreaseRunningTasks(1) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - try { - val result = ser.deserialize[TaskResult[_]](serializedData) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - } catch { - case cnf: ClassNotFoundException => - val loader = Thread.currentThread().getContextClassLoader - throw new SparkException("ClassNotFound with classloader: " + loader, cnf) - case ex => throw ex - } - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - decreaseRunningTasks(1) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - return - - case taskResultTooBig: TaskResultTooBigFailure => - logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format( - tid)) - abort("Task %s result exceeded Akka frame size".format(tid)) - return - - case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - decreaseRunningTasks(runningTasks) - sched.taskSetFinished(this) - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable: Schedulable) {} - - override def removeSchedulable(schedulable: Schedulable) {} - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because findTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = clock.getTime() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksFinished < numTasks - } - - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = System.getProperty("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - System.getProperty("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - System.getProperty("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - System.getProperty("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L - } - } - - /** - * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been - * added to queues using addPendingTask. - */ - private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} - val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { - levels += PROCESS_LOCAL - } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { - levels += NODE_LOCAL - } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { - levels += RACK_LOCAL - } - levels += ANY - logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) - levels.toArray - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala deleted file mode 100644 index 8825f2dd24..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.executor.ExecutorExitCode - -/** - * Represents an explanation for a executor or whole slave failing or exiting. - */ -private[spark] -class ExecutorLossReason(val message: String) { - override def toString: String = message -} - -private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { -} - -private[spark] -case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala deleted file mode 100644 index 83708f07e1..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.Logging -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** - * An Schedulable entity that represent collection of Pools or TaskSetManagers - */ - -private[spark] class Pool( - val poolName: String, - val schedulingMode: SchedulingMode, - initMinShare: Int, - initWeight: Int) - extends Schedulable - with Logging { - - var schedulableQueue = new ArrayBuffer[Schedulable] - var schedulableNameToSchedulable = new HashMap[String, Schedulable] - - var weight = initWeight - var minShare = initMinShare - var runningTasks = 0 - - var priority = 0 - var stageId = 0 - var name = poolName - var parent:Schedulable = null - - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { - schedulingMode match { - case SchedulingMode.FAIR => - new FairSchedulingAlgorithm() - case SchedulingMode.FIFO => - new FIFOSchedulingAlgorithm() - } - } - - override def addSchedulable(schedulable: Schedulable) { - schedulableQueue += schedulable - schedulableNameToSchedulable(schedulable.name) = schedulable - schedulable.parent= this - } - - override def removeSchedulable(schedulable: Schedulable) { - schedulableQueue -= schedulable - schedulableNameToSchedulable -= schedulable.name - } - - override def getSchedulableByName(schedulableName: String): Schedulable = { - if (schedulableNameToSchedulable.contains(schedulableName)) { - return schedulableNameToSchedulable(schedulableName) - } - for (schedulable <- schedulableQueue) { - var sched = schedulable.getSchedulableByName(schedulableName) - if (sched != null) { - return sched - } - } - return null - } - - override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) - } - - override def checkSpeculatableTasks(): Boolean = { - var shouldRevive = false - for (schedulable <- schedulableQueue) { - shouldRevive |= schedulable.checkSpeculatableTasks() - } - return shouldRevive - } - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) - for (schedulable <- sortedSchedulableQueue) { - sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() - } - return sortedTaskSetQueue - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def hasPendingTasks(): Boolean = { - schedulableQueue.exists(_.hasPendingTasks()) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala deleted file mode 100644 index e77e8e4162..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -import scala.collection.mutable.ArrayBuffer -/** - * An interface for schedulable entities. - * there are two type of Schedulable entities(Pools and TaskSetManagers) - */ -private[spark] trait Schedulable { - var parent: Schedulable - // child queues - def schedulableQueue: ArrayBuffer[Schedulable] - def schedulingMode: SchedulingMode - def weight: Int - def minShare: Int - def runningTasks: Int - def priority: Int - def stageId: Int - def name: String - - def increaseRunningTasks(taskNum: Int): Unit - def decreaseRunningTasks(taskNum: Int): Unit - def addSchedulable(schedulable: Schedulable): Unit - def removeSchedulable(schedulable: Schedulable): Unit - def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit - def checkSpeculatableTasks(): Boolean - def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] - def hasPendingTasks(): Boolean -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala deleted file mode 100644 index 2fc8a76a05..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.io.{File, FileInputStream, FileOutputStream, FileNotFoundException} -import java.util.Properties - -import scala.xml.XML - -import spark.Logging -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - - -/** - * An interface to build Schedulable tree - * buildPools: build the tree nodes(pools) - * addTaskSetManager: build the leaf nodes(TaskSetManagers) - */ -private[spark] trait SchedulableBuilder { - def buildPools() - def addTaskSetManager(manager: Schedulable, properties: Properties) -} - -private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) - extends SchedulableBuilder with Logging { - - override def buildPools() { - // nothing - } - - override def addTaskSetManager(manager: Schedulable, properties: Properties) { - rootPool.addSchedulable(manager) - } -} - -private[spark] class FairSchedulableBuilder(val rootPool: Pool) - extends SchedulableBuilder with Logging { - - val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file") - val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" - val DEFAULT_POOL_NAME = "default" - val MINIMUM_SHARES_PROPERTY = "minShare" - val SCHEDULING_MODE_PROPERTY = "schedulingMode" - val WEIGHT_PROPERTY = "weight" - val POOL_NAME_PROPERTY = "@name" - val POOLS_PROPERTY = "pool" - val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO - val DEFAULT_MINIMUM_SHARE = 2 - val DEFAULT_WEIGHT = 1 - - override def buildPools() { - if (schedulerAllocFile != null) { - val file = new File(schedulerAllocFile) - if (file.exists()) { - val xml = XML.loadFile(file) - for (poolNode <- (xml \\ POOLS_PROPERTY)) { - - val poolName = (poolNode \ POOL_NAME_PROPERTY).text - var schedulingMode = DEFAULT_SCHEDULING_MODE - var minShare = DEFAULT_MINIMUM_SHARE - var weight = DEFAULT_WEIGHT - - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text - if (xmlSchedulingMode != "") { - try { - schedulingMode = SchedulingMode.withName(xmlSchedulingMode) - } catch { - case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode") - } - } - - val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text - if (xmlMinShare != "") { - minShare = xmlMinShare.toInt - } - - val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text - if (xmlWeight != "") { - weight = xmlWeight.toInt - } - - val pool = new Pool(poolName, schedulingMode, minShare, weight) - rootPool.addSchedulable(pool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, schedulingMode, minShare, weight)) - } - } else { - throw new java.io.FileNotFoundException( - "Fair scheduler allocation file not found: " + schedulerAllocFile) - } - } - - // finally create "default" pool - if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { - val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(pool) - logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) - } - } - - override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) - } - } - parentPool.addSchedulable(manager) - logInfo("Added task set " + manager.name + " tasks to pool "+poolName) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala deleted file mode 100644 index 4431744ec3..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.{SparkContext, Utils} - -/** - * A backend interface for cluster scheduling systems that allows plugging in different ones under - * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as - * machines become available and can launch tasks on them. - */ -private[spark] trait SchedulerBackend { - def start(): Unit - def stop(): Unit - def reviveOffers(): Unit - def defaultParallelism(): Int - - // Memory used by each executor (in megabytes) - protected val executorMemory: Int = SparkContext.executorMemoryRequested - - // TODO: Probably want to add a killTask too -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala deleted file mode 100644 index 69e0ac2a6b..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -/** - * An interface for sort algorithm - * FIFO: FIFO algorithm between TaskSetManagers - * FS: FS algorithm between Pools, and FIFO or FS within Pools - */ -private[spark] trait SchedulingAlgorithm { - def comparator(s1: Schedulable, s2: Schedulable): Boolean -} - -private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { - override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { - val priority1 = s1.priority - val priority2 = s2.priority - var res = math.signum(priority1 - priority2) - if (res == 0) { - val stageId1 = s1.stageId - val stageId2 = s2.stageId - res = math.signum(stageId1 - stageId2) - } - if (res < 0) { - return true - } else { - return false - } - } -} - -private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { - override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { - val minShare1 = s1.minShare - val minShare2 = s2.minShare - val runningTasks1 = s1.runningTasks - val runningTasks2 = s2.runningTasks - val s1Needy = runningTasks1 < minShare1 - val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble - val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble - val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var res:Boolean = true - var compare:Int = 0 - - if (s1Needy && !s2Needy) { - return true - } else if (!s1Needy && s2Needy) { - return false - } else if (s1Needy && s2Needy) { - compare = minShareRatio1.compareTo(minShareRatio2) - } else { - compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) - } - - if (compare < 0) { - return true - } else if (compare > 0) { - return false - } else { - return s1.name < s2.name - } - } -} - diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala deleted file mode 100644 index 55cdf4791f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -/** - * "FAIR" and "FIFO" determines which policy is used - * to order tasks amongst a Schedulable's sub-queues - * "NONE" is used when the a Schedulable has no sub-queues. - */ -object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") { - - type SchedulingMode = Value - val FAIR,FIFO,NONE = Value -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala deleted file mode 100644 index 7ac574bdc8..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.{Utils, Logging, SparkContext} -import spark.deploy.client.{Client, ClientListener} -import spark.deploy.{Command, ApplicationDescription} -import scala.collection.mutable.HashMap - -private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) - with ClientListener - with Logging { - - var client: Client = null - var stopping = false - var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - - override def start() { - super.start() - - // The endpoint for executors to talk to us - val driverUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") - val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val sparkHome = sc.getSparkHome().getOrElse(null) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, - sc.ui.appUIAddress) - - client = new Client(sc.env.actorSystem, master, appDesc, this) - client.start() - } - - override def stop() { - stopping = true - super.stop() - client.stop() - if (shutdownCallback != null) { - shutdownCallback(this) - } - } - - override def connected(appId: String) { - logInfo("Connected to Spark cluster with app ID " + appId) - } - - override def disconnected() { - if (!stopping) { - logError("Disconnected from Spark cluster!") - scheduler.error("Disconnected from Spark cluster") - } - } - - override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( - executorId, hostPort, cores, Utils.megabytesToString(memory))) - } - - override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { - val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) - case None => SlaveLost(message) - } - logInfo("Executor %s removed: %s".format(executorId, message)) - removeExecutor(executorId, reason.toString) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala deleted file mode 100644 index 05c29eb72f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer - -import spark.TaskState.TaskState -import spark.Utils -import spark.util.SerializableBuffer - - -private[spark] sealed trait StandaloneClusterMessage extends Serializable - -private[spark] object StandaloneClusterMessages { - - // Driver to executors - case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage - - case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends StandaloneClusterMessage - - case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage - - // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends StandaloneClusterMessage { - Utils.checkHostPort(hostPort, "Expected host port") - } - - case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends StandaloneClusterMessage - - object StatusUpdate { - /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer) - : StatusUpdate = { - StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) - } - } - - // Internal messages in driver - case object ReviveOffers extends StandaloneClusterMessage - - case object StopDriver extends StandaloneClusterMessage - - case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage - -} diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala deleted file mode 100644 index 3203be1029..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ -import akka.dispatch.Await -import akka.pattern.ask -import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} -import akka.util.Duration -import akka.util.duration._ - -import spark.{Utils, SparkException, Logging, TaskState} -import spark.scheduler.cluster.StandaloneClusterMessages._ - -/** - * A standalone scheduler backend, which waits for standalone executors to connect to it through - * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained - * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). - */ -private[spark] -class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) - extends SchedulerBackend with Logging -{ - // Use an atomic variable to track total number of cores in the cluster for simplicity and speed - var totalCoreCount = new AtomicInteger(0) - - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - private val executorActor = new HashMap[String, ActorRef] - private val executorAddress = new HashMap[String, Address] - private val executorHost = new HashMap[String, String] - private val freeCores = new HashMap[String, Int] - private val actorToExecutorId = new HashMap[ActorRef, String] - private val addressToExecutorId = new HashMap[Address, String] - - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - - // Periodically revive offers to allow delay scheduling to work - val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receive = { - case RegisterExecutor(executorId, hostPort, cores) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorActor.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor(sparkProperties) - context.watch(sender) - executorActor(executorId) = sender - executorHost(executorId) = Utils.parseHostPort(hostPort)._1 - freeCores(executorId) = cores - executorAddress(executorId) = sender.path.address - actorToExecutorId(sender) = executorId - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - makeOffers() - } - - case StatusUpdate(executorId, taskId, state, data) => - scheduler.statusUpdate(taskId, state, data.value) - if (TaskState.isFinished(state)) { - freeCores(executorId) += 1 - makeOffers(executorId) - } - - case ReviveOffers => - makeOffers() - - case StopDriver => - sender ! true - context.stop(self) - - case RemoveExecutor(executorId, reason) => - removeExecutor(executorId, reason) - sender ! true - - case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) - - case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) - - case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) - } - - // Make fake resource offers on all executors - def makeOffers() { - launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) - } - - // Make fake resource offers on just one executor - def makeOffers(executorId: String) { - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) - } - - // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { - for (task <- tasks.flatten) { - freeCores(task.executorId) -= 1 - executorActor(task.executorId) ! LaunchTask(task) - } - } - - // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String) { - if (executorActor.contains(executorId)) { - logInfo("Executor " + executorId + " disconnected, so removing it") - val numCores = freeCores(executorId) - actorToExecutorId -= executorActor(executorId) - addressToExecutorId -= executorAddress(executorId) - executorActor -= executorId - executorHost -= executorId - freeCores -= executorId - totalCoreCount.addAndGet(-numCores) - scheduler.executorLost(executorId, SlaveLost(reason)) - } - } - } - - var driverActor: ActorRef = null - val taskIdsOnSlave = new HashMap[String, HashSet[String]] - - override def start() { - val properties = new ArrayBuffer[(String, String)] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { - properties += ((key, value)) - } - } - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) - } - - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - override def stop() { - try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.result(future, timeout) - } - } catch { - case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) - } - } - - override def reviveOffers() { - driverActor ! ReviveOffers - } - - override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) - .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) - - // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { - try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) - } - } -} - -private[spark] object StandaloneSchedulerBackend { - val ACTOR_NAME = "StandaloneScheduler" -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala deleted file mode 100644 index 187553233f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer -import spark.util.SerializableBuffer - -private[spark] class TaskDescription( - val taskId: Long, - val executorId: String, - val name: String, - val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) - extends Serializable { - - // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) - - def serializedTask: ByteBuffer = buffer.value - - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala deleted file mode 100644 index c2c5522686..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.Utils - -/** - * Information about a running task attempt inside a TaskSet. - */ -private[spark] -class TaskInfo( - val taskId: Long, - val index: Int, - val launchTime: Long, - val executorId: String, - val host: String, - val taskLocality: TaskLocality.TaskLocality) { - - var finishTime: Long = 0 - var failed = false - - def markSuccessful(time: Long = System.currentTimeMillis) { - finishTime = time - } - - def markFailed(time: Long = System.currentTimeMillis) { - finishTime = time - failed = true - } - - def finished: Boolean = finishTime != 0 - - def successful: Boolean = finished && !failed - - def running: Boolean = !finished - - def status: String = { - if (running) - "RUNNING" - else if (failed) - "FAILED" - else if (successful) - "SUCCESS" - else - "UNKNOWN" - } - - def duration: Long = { - if (!finished) { - throw new UnsupportedOperationException("duration() called on unfinished tasks") - } else { - finishTime - launchTime - } - } - - def timeRunning(currentTime: Long): Long = currentTime - launchTime -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala deleted file mode 100644 index 1c33e41f87..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - - -private[spark] object TaskLocality - extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") -{ - // process local is expected to be used ONLY within tasksetmanager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value - - type TaskLocality = Value - - def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { - condition <= constraint - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala deleted file mode 100644 index 0248830b7a..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer - -import spark.TaskState.TaskState -import spark.scheduler.TaskSet - -/** - * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of - * each task and is responsible for retries on failure and locality. The main interfaces to it - * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and - * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler - * (e.g. its event handlers). It should not be called from other threads. - */ -private[spark] trait TaskSetManager extends Schedulable { - def schedulableQueue = null - - def schedulingMode = SchedulingMode.NONE - - def taskSet: TaskSet - - def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) - - def error(message: String) -} diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala deleted file mode 100644 index 1d09bd9b03..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -/** - * Represents free resources available on an executor. - */ -private[spark] -class WorkerOffer(val executorId: String, val host: String, val cores: Int) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala deleted file mode 100644 index 5be4dbd9f0..0000000000 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.local - -import java.io.File -import java.lang.management.ManagementFactory -import java.util.concurrent.atomic.AtomicInteger -import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import spark._ -import spark.TaskState.TaskState -import spark.executor.ExecutorURLClassLoader -import spark.scheduler._ -import spark.scheduler.cluster._ -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -import akka.actor._ - -/** - * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. - */ - -private[spark] -case class LocalReviveOffers() - -private[spark] -case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) - -private[spark] -class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { - - def receive = { - case LocalReviveOffers => - launchTask(localScheduler.resourceOffer(freeCores)) - case LocalStatusUpdate(taskId, state, serializeData) => - freeCores += 1 - localScheduler.statusUpdate(taskId, state, serializeData) - launchTask(localScheduler.resourceOffer(freeCores)) - } - - def launchTask(tasks : Seq[TaskDescription]) { - for (task <- tasks) { - freeCores -= 1 - localScheduler.threadPool.submit(new Runnable { - def run() { - localScheduler.runTask(task.taskId, task.serializedTask) - } - }) - } - } -} - -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) - extends TaskScheduler - with Logging { - - var attemptId = new AtomicInteger(0) - var threadPool = Utils.newDaemonFixedThreadPool(threads) - val env = SparkEnv.get - var listener: TaskSchedulerListener = null - - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. - val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.cluster.schedulingmode", "FIFO")) - val activeTaskSets = new HashMap[String, TaskSetManager] - val taskIdToTaskSetId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - var localActor: ActorRef = null - - override def start() { - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - - localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") - } - - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener - } - - override def submitTasks(taskSet: TaskSet) { - synchronized { - val manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers - } - } - - def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - synchronized { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format( - manager.parent.name, manager.name, manager.runningTasks)) - } - - var launchTask = false - for (manager <- sortedTaskSetQueue) { - do { - launchTask = false - manager.resourceOffer(null, null, freeCpuCores, null) match { - case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true - case None => {} - } - } while(launchTask) - } - return tasks - } - } - - def taskSetFinished(manager: TaskSetManager) { - synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id - } - } - - def runTask(taskId: Long, bytes: ByteBuffer) { - logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objectSer = SparkEnv.get.serializer.newInstance() - var attemptedTask: Option[Task[_]] = None - val start = System.currentTimeMillis() - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(deserializedTask) - val deserTime = System.currentTimeMillis() - start - taskStart = System.currentTimeMillis() - - // Run it - val result: Any = deserializedTask.run(taskId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = objectSer.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = objectSer.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - val serviceTime = System.currentTimeMillis() - taskStart - logInfo("Finished " + taskId) - deserializedTask.metrics.get.executorRunTime = serviceTime.toInt - deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) - val serializedResult = ser.serialize(taskResult) - localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) - } catch { - case t: Throwable => { - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime.toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. - */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) - } - } - } - } - - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { - synchronized { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - taskSetManager.statusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - threadPool.shutdownNow() - } - - override def defaultParallelism() = threads -} diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala deleted file mode 100644 index e237f289e3..0000000000 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.local - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState} -import spark.TaskState.TaskState -import spark.scheduler.{Task, TaskResult, TaskSet} -import spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager} - - -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - - var parent: Schedulable = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_" + taskSet.stageId.toString - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val MAX_TASK_FAILURES = sched.maxFailures - - override def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def removeSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - // nothing - } - - override def checkSpeculatableTasks() = true - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def hasPendingTasks() = true - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - SparkEnv.set(sched.env) - logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( - availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", - TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val bytes = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, index, bytes)) - case None => {} - } - } - return None - } - - override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( - serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > MAX_TASK_FAILURES) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, 4, reason.description) - decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) - // need to delete failed Taskset from schedule queue - sched.taskSetFinished(this) - } - } - } - - override def error(message: String) { - } -} diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala deleted file mode 100644 index eef3ee1425..0000000000 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.mesos - -import com.google.protobuf.ByteString - -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import spark.{SparkException, Utils, Logging, SparkContext} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ -import java.io.File -import spark.scheduler.cluster._ -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections -import spark.TaskState - -/** - * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds - * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever - * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the - * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. - * - * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to - * remove this. - */ -private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) - with MScheduler - with Logging { - - val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - - // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] - var totalCoresAcquired = 0 - - val slaveIdsWithExecutors = new HashSet[String] - - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) - - val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt - - var nextMesosTaskId = 0 - - def newMesosTaskId(): Int = { - val id = nextMesosTaskId - nextMesosTaskId += 1 - id - } - - override def start() { - super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } - } - - def createCommand(offer: Offer, numCores: Int): CommandInfo = { - val environment = Environment.newBuilder() - sc.executorEnvs.foreach { case (key, value) => - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(value) - .build()) - } - val command = CommandInfo.newBuilder() - .setEnvironment(environment) - val driverUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) - val uri = System.getProperty("spark.executor.uri") - if (uri == null) { - val runScript = new File(sparkHome, "spark-class").getCanonicalPath - command.setValue("\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) - } else { - // Grab everything to the first '.'. We'll use that and '*' to - // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./spark-class spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) - } - return command.build() - } - - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - logInfo("Registered as framework ID " + frameworkId.getValue) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } - } - - override def disconnected(d: SchedulerDriver) {} - - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} - - /** - * Method called by Mesos to offer resources on slaves. We respond by launching an executor, - * unless we've already launched more than we wanted to. - */ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() - - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) - .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", executorMemory)) - .build() - d.launchTasks(offer.getId, Collections.singletonList(task), filters) - } else { - // Filter it out - d.launchTasks(offer.getId, Collections.emptyList[MesosTaskInfo](), filters) - } - } - } - } - - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) - } - - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } - - /** Check whether a Mesos task state represents a finished task */ - private def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) - synchronized { - if (isFinished(state)) { - val slaveId = taskIdToSlaveId(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId - // Remove the cores we have remembered for this task, if it's in the hashmap - for (cores <- coresByTaskId.get(taskId)) { - totalCoresAcquired -= cores - coresByTaskId -= taskId - } - // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + - "is Spark installed on it?") - } - } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node - } - } - } - - override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) - scheduler.error(message) - } - - override def stop() { - super.stop() - if (driver != null) { - driver.stop() - } - } - - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") - } - } - } - - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) - } -} diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala deleted file mode 100644 index f6069a5775..0000000000 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.mesos - -import com.google.protobuf.ByteString - -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import spark.{SparkException, Utils, Logging, SparkContext} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ -import java.io.File -import spark.scheduler.cluster._ -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections -import spark.TaskState - -/** - * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a - * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks - * from multiple apps can run on different cores) and in time (a core can switch ownership). - */ -private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends SchedulerBackend - with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Long, String] - - // An ExecutorInfo for our tasks - var execArgs: Array[Byte] = null - - var classLoader: ClassLoader = null - - override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } - } - - def createExecutorInfo(execId: String): ExecutorInfo = { - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val environment = Environment.newBuilder() - sc.executorEnvs.foreach { case (key, value) => - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(value) - .build()) - } - val command = CommandInfo.newBuilder() - .setEnvironment(environment) - val uri = System.getProperty("spark.executor.uri") - if (uri == null) { - command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath) - } else { - // Grab everything to the first '.'. We'll use that and '*' to - // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./spark-executor".format(basename)) - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) - } - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) - .build() - ExecutorInfo.newBuilder() - .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) - .setCommand(command) - .setData(ByteString.copyFrom(createExecArg())) - .addResources(memory) - .build() - } - - /** - * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array - * containing all the spark.* system properties in the form of (String, String) pairs. - */ - private def createExecArg(): Array[Byte] = { - if (execArgs == null) { - val props = new HashMap[String, String] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { - props(key) = value - } - } - // Serialize the map as an array of (String, String) pairs - execArgs = Utils.serialize(props.toArray) - } - return execArgs - } - - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - return oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { - logInfo("Registered as framework ID " + frameworkId.getValue) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } finally { - restoreClassLoader(oldClassLoader) - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } - } - - override def disconnected(d: SchedulerDriver) {} - - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} - - /** - * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets - * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that - * tasks are balanced across the cluster. - */ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableIndices = new ArrayBuffer[Int] - val offerableWorkers = new ArrayBuffer[WorkerOffer] - - def enoughMemory(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val slaveId = o.getSlaveId.getValue - mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) - } - - for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices += index - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - getResource(offer.getResourcesList, "cpus").toInt) - } - - // Call into the ClusterScheduler - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - val offerNum = offerableIndices(index) - val slaveId = offers(offerNum).getSlaveId.getValue - slaveIdsWithExecutors += slaveId - mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size) - for (taskDesc <- taskList) { - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } - } - } - - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? - for (i <- 0 until offers.size) { - d.launchTasks(offers(i).getId, mesosTasks(i), filters) - } - } - } finally { - restoreClassLoader(oldClassLoader) - } - } - - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) - } - - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { - val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(1).build()) - .build() - return MesosTaskInfo.newBuilder() - .setTaskId(taskId) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) - .setName(task.name) - .addResources(cpuResource) - .setData(ByteString.copyFrom(task.serializedTask)) - .build() - } - - /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { - val tid = status.getTaskId.getValue.toLong - val state = TaskState.fromMesos(status.getState) - synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's gone - slaveIdsWithExecutors -= taskIdToSlaveId(tid) - } - if (isFinished(status.getState)) { - taskIdToSlaveId.remove(tid) - } - } - scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { - logError("Mesos error: " + message) - scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def stop() { - if (driver != null) { - driver.stop() - } - } - - override def reviveOffers() { - driver.reviveOffers() - } - - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - - private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - slaveIdsWithExecutors -= slaveId.getValue - } - scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - recordSlaveLost(d, slaveId, SlaveLost()) - } - - override def executorLost(d: SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, - slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) - } - - // TODO: query Mesos for number of cores - override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt -} diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala deleted file mode 100644 index dc94d42bb6..0000000000 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.serializer - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import spark.util.ByteBufferInputStream - - -/** - * A serializer. Because some serialization libraries are not thread safe, this class is used to - * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are - * guaranteed to only be called from one thread at a time. - */ -trait Serializer { - def newInstance(): SerializerInstance -} - - -/** - * An instance of a serializer, for use by one thread at a time. - */ -trait SerializerInstance { - def serialize[T](t: T): ByteBuffer - - def deserialize[T](bytes: ByteBuffer): T - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T - - def serializeStream(s: OutputStream): SerializationStream - - def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new FastByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.allocate(stream.position.toInt) - buffer.put(stream.array, 0, stream.position.toInt) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } -} - - -/** - * A stream for writing serialized objects. - */ -trait SerializationStream { - def writeObject[T](t: T): SerializationStream - def flush(): Unit - def close(): Unit - - def writeAll[T](iter: Iterator[T]): SerializationStream = { - while (iter.hasNext) { - writeObject(iter.next()) - } - this - } -} - - -/** - * A stream for reading serialized objects. - */ -trait DeserializationStream { - def readObject[T](): T - def close(): Unit - - /** - * Read the elements of this stream through an iterator. This can only be called once, as - * reading each element will consume data from the input source. - */ - def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] { - override protected def getNext() = { - try { - readObject[Any]() - } catch { - case eof: EOFException => - finished = true - } - } - - override protected def close() { - DeserializationStream.this.close() - } - } -} diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala deleted file mode 100644 index b7b24705a2..0000000000 --- a/core/src/main/scala/spark/serializer/SerializerManager.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.serializer - -import java.util.concurrent.ConcurrentHashMap - - -/** - * A service that returns a serializer object given the serializer's class name. If a previous - * instance of the serializer object has been created, the get method returns that instead of - * creating a new one. - */ -private[spark] class SerializerManager { - - private val serializers = new ConcurrentHashMap[String, Serializer] - private var _default: Serializer = _ - - def default = _default - - def setDefault(clsName: String): Serializer = { - _default = get(clsName) - _default - } - - def get(clsName: String): Serializer = { - if (clsName == null) { - default - } else { - var serializer = serializers.get(clsName) - if (serializer != null) { - // If the serializer has been created previously, reuse that. - serializer - } else this.synchronized { - // Otherwise, create a new one. But make sure no other thread has attempted - // to create another new one at the same time. - serializer = serializers.get(clsName) - if (serializer == null) { - val clsLoader = Thread.currentThread.getContextClassLoader - serializer = - Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] - serializers.put(clsName, serializer) - } - serializer - } - } - } -} diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala deleted file mode 100644 index 8ebfaf3cbf..0000000000 --- a/core/src/main/scala/spark/storage/BlockException.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -private[spark] -case class BlockException(blockId: String, message: String) extends Exception(message) - diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala deleted file mode 100644 index 265e554ad8..0000000000 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long -} diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index 568783d893..0000000000 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue - -import io.netty.buffer.ByteBuf - -import spark.Logging -import spark.Utils -import spark.SparkException -import spark.network.BufferMessage -import spark.network.ConnectionManagerId -import spark.network.netty.ShuffleCopier -import spark.serializer.Serializer - - -/** - * A block fetcher iterator interface. There are two implementations: - * - * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. - * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. - * - * Eventually we would like the two to converge and use a single NIO-based communication layer, - * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), - * NIO would perform poorly and thus the need for the Netty OIO one. - */ - -private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] - with Logging with BlockFetchTracker { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } - - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer) - extends BlockFetcherIterator { - - import blockManager._ - - private var _remoteBytesRead = 0l - private var _remoteFetchTime = 0l - private var _fetchWaitTime = 0l - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number blocks fetched (local + remote). Also number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[String]() - - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[String]() - - // A queue to hold our results. - protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - private val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - private var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val fetchStart = System.currentTimeMillis() - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { - val fetchDone = System.currentTimeMillis() - _remoteFetchTime += fetchDone - fetchStart - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case None => { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - numLocal = blockInfos.size - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - numRemote += blockInfos.size - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numGets = remoteRequests.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - //an iterator that will read fetched blocks off the queue as they arrive. - @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) - if (! result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - - // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = numLocal + numRemote - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def remoteFetchTime: Long = _remoteFetchTime - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - } - // End of BasicBlockFetcherIterator - - class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { - - import blockManager._ - - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - - private def startCopiers(numCopiers: Int): List[_ <: Thread] = { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") - } - } - } - copier.start - copier - }).toList - } - - // keep this to interrupt the threads when necessary - private def stopCopiers() { - for (copier <- copiers) { - copier.interrupt() - } - } - - override protected def sendRequest(req: FetchRequest) { - - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) - results.put(fetchResult) - } - - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier - cpier.getBlocks(cmId, req.blocks, putResult) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) - } - - private var copiers: List[_ <: Thread] = null - - override def initialize() { - // Split Local Remote Blocks and set numBlocksToFetch - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) - logInfo("Started " + fetchRequestsSync.size + " remote gets in " + - Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - override def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // If all the results has been retrieved, copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of NettyBlockFetcherIterator -} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala deleted file mode 100644 index 2a6ec2a55d..0000000000 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ /dev/null @@ -1,1046 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{InputStream, OutputStream} -import java.nio.{ByteBuffer, MappedByteBuffer} - -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} - -import akka.actor.{ActorSystem, Cancellable, Props} -import akka.dispatch.{Await, Future} -import akka.util.Duration -import akka.util.duration._ - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import spark.{Logging, SparkEnv, SparkException, Utils} -import spark.io.CompressionCodec -import spark.network._ -import spark.serializer.Serializer -import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} - -import sun.nio.ch.DirectBuffer - - -private[spark] class BlockManager( - executorId: String, - actorSystem: ActorSystem, - val master: BlockManagerMaster, - val defaultSerializer: Serializer, - maxMemory: Long) - extends Logging { - - private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - @volatile var pending: Boolean = true - @volatile var size: Long = -1L - @volatile var initThread: Thread = null - @volatile var failed = false - - setInitThread() - - private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) - this.initThread = Thread.currentThread() - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (initThread != Thread.currentThread() && pending) { - synchronized { - while (pending) this.wait() - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. block is finished writing) */ - def markReady(sizeInBytes: Long) { - assert (pending) - size = sizeInBytes - initThread = null - failed = false - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert (pending) - size = 0 - initThread = null - failed = true - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - } - - val shuffleBlockManager = new ShuffleBlockManager(this) - - private val blockInfo = new TimeStampedHashMap[String, BlockInfo] - - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: DiskStore = - new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private val nettyPort: Int = { - val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 - } - - val connectionManager = new ConnectionManager(0) - implicit val futureExecContext = connectionManager.futureExecContext - - val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) - - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory - // for receiving shuffle outputs) - val maxBytesInFlight = - System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 - - // Whether to compress broadcast variables that are stored - val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean - // Whether to compress shuffle output that are stored - val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean - // Whether to compress RDD partitions that are stored serialized - val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean - - val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - - val hostPort = Utils.localHostPort() - - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - - // Pending reregistration action being executed asynchronously or null if none - // is pending. Accesses should synchronize on asyncReregisterLock. - var asyncReregisterTask: Future[Unit] = null - val asyncReregisterLock = new Object - - private def heartBeat() { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - var heartBeatTask: Cancellable = null - - val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) - initialize() - - // The compression codec to use. Note that the "lazy" val is necessary because we want to delay - // the initialization of the compression codec until it is first used. The reason is that a Spark - // program could be using a user-defined codec in a third party jar, which is loaded in - // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - // loaded yet. - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec() - - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) - } - - /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. - */ - private def initialize() { - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - heartBeat() - } - } - } - - /** - * Report all blocks to the BlockManager again. This may be necessary if we are dropped - * by the BlockManager and come back or if we become capable of recovering blocks on disk after - * an executor crash. - * - * This function deliberately fails silently if the master returns false (indicating that - * the slave needs to reregister). The error condition will be detected again by the next - * heart beat attempt or new block registration and another try to reregister all blocks - * will be made then. - */ - private def reportAllBlocks() { - logInfo("Reporting " + blockInfo.size + " blocks to the master.") - for ((blockId, info) <- blockInfo) { - if (!tryToReportBlockStatus(blockId, info)) { - logError("Failed to report " + blockId + " to master; giving up.") - return - } - } - } - - /** - * Reregister with the master and report all blocks to it. This will be called by the heart beat - * thread if our heartbeat to the block amnager indicates that we were not registered. - * - * Note that this method must be called without any BlockInfo locks held. - */ - def reregister() { - // TODO: We might need to rate limit reregistering. - logInfo("BlockManager reregistering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - reportAllBlocks() - } - - /** - * Reregister with the master sometime soon. - */ - def asyncReregister() { - asyncReregisterLock.synchronized { - if (asyncReregisterTask == null) { - asyncReregisterTask = Future[Unit] { - reregister() - asyncReregisterLock.synchronized { - asyncReregisterTask = null - } - } - } - } - } - - /** - * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. - */ - def waitForAsyncReregister() { - val task = asyncReregisterTask - if (task != null) { - Await.ready(task, Duration.Inf) - } - } - - /** - * Get storage level of local block. If no info exists for the block, then returns null. - */ - def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull - - /** - * Tell the master about the current storage status of a block. This will send a block update - * message reflecting the current status, *not* the desired storage level in its block info. - * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. - * - * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). - * This ensures that update in master will compensate for the increase in memory on slave. - */ - def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { - val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) - if (needReregister) { - logInfo("Got told to reregister updating block " + blockId) - // Reregistering will report our new block for free. - asyncReregister() - } - logDebug("Told master about block " + blockId) - } - - /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, - * which will be true if the block was successfully recorded and false if - * the slave needs to re-register. - */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { - val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { - info.level match { - case null => - (StorageLevel.NONE, 0L, 0L, false) - case level => - val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize - val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - (storageLevel, memSize, diskSize, info.tellMaster) - } - } - - if (tellMaster) { - master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) - } else { - true - } - } - - /** - * Get locations of an array of blocks. - */ - def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).toArray - logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) - locations - } - - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( - sys.error("Block " + blockId + " not found on disk, though it should be")) - } - - /** - * Get block from local block manager. - */ - def getLocal(blockId: String): Option[Iterator[Any]] = { - logDebug("Getting local block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk, potentially loading it back into memory if required - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - if (level.useMemory && level.deserialized) { - diskStore.getValues(blockId) match { - case Some(iterator) => - // Put the block back in memory before returning it - // TODO: Consider creating a putValues that also takes in a iterator ? - val elements = new ArrayBuffer[Any] - elements ++= iterator - memoryStore.putValues(blockId, elements, level, true).data match { - case Left(iterator2) => - return Some(iterator2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else if (level.useMemory && !level.deserialized) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - // Put a copy of the block back in memory before returning it. Note that we can't - // put the ByteBuffer returned by the disk store as that's a memory-mapped file. - // The use of rewind assumes this. - assert (0 == bytes.position()) - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - bytes.rewind() - return Some(dataDeserialize(blockId, bytes)) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else { - diskStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None - } - - /** - * Get block from the local block manager as serialized bytes. - */ - def getLocalBytes(blockId: String): Option[ByteBuffer] = { - // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow - logDebug("Getting local block " + blockId + " as bytes") - - // As an optimization for map output fetches, if the block is for a shuffle, return it - // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (ShuffleBlockManager.isShuffle(blockId)) { - return diskStore.getBytes(blockId) match { - case Some(bytes) => - Some(bytes) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getBytes(blockId) match { - case Some(bytes) => - return Some(bytes) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk - if (level.useDisk) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - assert (0 == bytes.position()) - if (level.useMemory) { - if (level.deserialized) { - memoryStore.putBytes(blockId, bytes, level) - } else { - // The memory store will hang onto the ByteBuffer, so give it a copy instead of - // the memory-mapped file buffer we got from the disk store - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - } - } - bytes.rewind() - return Some(bytes) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None - } - - /** - * Get block from remote block managers. - */ - def getRemote(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = master.getLocations(blockId) - - // Get block from remote locations - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) - if (data != null) { - return Some(dataDeserialize(blockId, data)) - } - logDebug("The value of block " + blockId + " is null") - } - logDebug("Block " + blockId + " not found") - return None - } - - /** - * Get a block from the block manager (either local or remote). - */ - def get(blockId: String): Option[Iterator[Any]] = { - getLocal(blockId).orElse(getRemote(blockId)) - } - - /** - * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns - * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. Expects a size in bytes to be provided for each block fetched, - * so that we can control the maxMegabytesInFlight for the fetch. - */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) - : BlockFetcherIterator = { - - val iter = - if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) - } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) - } - - iter.initialize() - iter - } - - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) - : Long = { - val elements = new ArrayBuffer[Any] - elements ++= values - put(blockId, elements, level, tellMaster) - } - - /** - * A short circuited method to get a block writer that can write data directly to disk. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. - */ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) - writer.registerCloseEventHandler(() => { - val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) - blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.size()) - }) - writer - } - - /** - * Put a new block of values to the block manager. Returns its (estimated) size in bytes. - */ - def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - tellMaster: Boolean = true) : Long = { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlockOpt.get.size - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // If we need to replicate the data, we'll want access to the values, but because our - // put will read the whole iterator, there will be no values left. For the case where - // the put serializes data, we'll remember the bytes, above; but for the case where it - // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. - var valuesAfterPut: Iterator[Any] = null - - // Ditto for the bytes after the put - var bytesAfterPut: ByteBuffer = null - - // Size of the block in bytes (to return to caller) - var size = 0L - - myInfo.synchronized { - logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will later - // drop it to disk if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } - } - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(size) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") - } - } - } - logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - - // Replicate block if required - if (level.replication > 1) { - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) - } - BlockManager.dispose(bytesAfterPut) - - return size - } - - - /** - * Put a new block of serialized bytes to the block manager. - */ - def putBytes( - blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper - Future { - replicate(blockId, bufferView, level) - } - } else { - null - } - - myInfo.synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Store it only in memory at first, even if useDisk is also set to true - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } else { - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - - // assert (0 == bytes.position(), "" + bytes) - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(bytes.limit) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") - } - } - } - - // If replication had started, then wait for it to finish - if (level.replication > 1) { - Await.ready(replicationFuture, Duration.Inf) - } - - if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } else { - logDebug("PutBytes for block " + blockId + " without replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } - } - - /** - * Replicate block to another node. - */ - var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - if (cachedPeers == null) { - cachedPeers = master.getPeers(blockManagerId, level.replication - 1) - } - for (peer: BlockManagerId <- cachedPeers) { - val start = System.nanoTime - data.rewind() - logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.limit() + " Bytes. To node: " + peer) - if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.host, peer.port))) { - logError("Failed to call syncPutBlock to " + peer) - } - logDebug("Replicated BlockId " + blockId + " once used " + - (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.limit() + " bytes.") - } - } - - /** - * Read a block consisting of a single object. - */ - def getSingle(blockId: String): Option[Any] = { - get(blockId).map(_.next()) - } - - /** - * Write a block consisting of a single object. - */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { - put(blockId, Iterator(value), level, tellMaster) - } - - /** - * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory - * store reaches its limit and needs to free up space. - */ - def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { - logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (! info.waitForReady() ) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure. Nothing to drop") - return - } - - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) - } - } - val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockWasRemoved = memoryStore.remove(blockId) - if (!blockWasRemoved) { - logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") - } - if (info.tellMaster) { - reportBlockStatus(blockId, info, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } - } - } else { - // The block has already been dropped - } - } - - /** - * Remove all blocks belonging to the given RDD. - * @return The number of blocks removed. - */ - def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. - logInfo("Removing RDD " + rddId) - val rddPrefix = "rdd_" + rddId + "_" - val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) - blocksToRemove.foreach(blockId => removeBlock(blockId, false)) - blocksToRemove.size - } - - /** - * Remove a block from both memory and disk. - */ - def removeBlock(blockId: String, tellMaster: Boolean = true) { - logInfo("Removing block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) info.synchronized { - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - if (!removedFromMemory && !removedFromDisk) { - logWarning("Block " + blockId + " could not be removed as it was not found in either " + - "the disk or memory store") - } - blockInfo.remove(blockId) - if (tellMaster && info.tellMaster) { - reportBlockStatus(blockId, info) - } - } else { - // The block has already been removed; do nothing. - logWarning("Asked to remove block " + blockId + ", which does not exist") - } - } - - def dropOldBlocks(cleanupTime: Long) { - logInfo("Dropping blocks older than " + cleanupTime) - val iterator = blockInfo.internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime) { - info.synchronized { - val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - iterator.remove() - logInfo("Dropped block " + id) - } - reportBlockStatus(id, info) - } - } - } - - def shouldCompress(blockId: String): Boolean = { - if (ShuffleBlockManager.isShuffle(blockId)) { - compressShuffle - } else if (blockId.startsWith("broadcast_")) { - compressBroadcast - } else if (blockId.startsWith("rdd_")) { - compressRdds - } else { - false // Won't happen in a real cluster, but it can in tests - } - } - - /** - * Wrap an output stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s - } - - /** - * Wrap an input stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: String, s: InputStream): InputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s - } - - def dataSerialize( - blockId: String, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new FastByteArrayOutputStream(4096) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() - byteStream.trim() - ByteBuffer.wrap(byteStream.array) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserialize( - blockId: String, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { - bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator - } - - def stop() { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } - connectionManager.stop() - actorSystem.stop(slaveActor) - blockInfo.clear() - memoryStore.clear() - diskStore.clear() - metadataCleaner.cancel() - logInfo("BlockManager stopped") - } -} - - -private[spark] object BlockManager extends Logging { - - val ID_GENERATOR = new IdGenerator - - def getMaxMemoryFromSystemProperties: Long = { - val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFraction).toLong - } - - def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 - - def getDisableHeartBeatsForTesting: Boolean = - System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean - - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. - */ - def dispose(buffer: ByteBuffer) { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace("Unmapping " + buffer) - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() - } - } - } - - def blockIdsToBlockManagers( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[BlockManagerId]] = - { - // env == null and blockManagerMaster != null is used in tests - assert (env != null || blockManagerMaster != null) - val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) { - env.blockManager.getLocationBlockIds(blockIds) - } else { - blockManagerMaster.getLocations(blockIds) - } - - val blockManagers = new HashMap[String, Seq[BlockManagerId]] - for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i) - } - blockManagers.toMap - } - - def blockIdsToExecutorIds( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = - { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) - } - - def blockIdsToHosts( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = - { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) - } -} - diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala deleted file mode 100644 index b36a6176c0..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap -import spark.Utils - -/** - * This class represent an unique identifier for a BlockManager. - * The first 2 constructors of this class is made private to ensure that - * BlockManagerId objects can be created only using the apply method in - * the companion object. This allows de-duplication of ID objects. - * Also, constructor parameters are private to ensure that parameters cannot - * be modified from outside this class. - */ -private[spark] class BlockManagerId private ( - private var executorId_ : String, - private var host_ : String, - private var port_ : Int, - private var nettyPort_ : Int - ) extends Externalizable { - - private def this() = this(null, null, 0, 0) // For deserialization only - - def executorId: String = executorId_ - - if (null != host_){ - Utils.checkHost(host_, "Expected hostname") - assert (port_ > 0) - } - - def hostPort: String = { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - host + ":" + port - } - - def host: String = host_ - - def port: Int = port_ - - def nettyPort: Int = nettyPort_ - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(executorId_) - out.writeUTF(host_) - out.writeInt(port_) - out.writeInt(nettyPort_) - } - - override def readExternal(in: ObjectInput) { - executorId_ = in.readUTF() - host_ = in.readUTF() - port_ = in.readInt() - nettyPort_ = in.readInt() - } - - @throws(classOf[IOException]) - private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - - override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) - - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort - - override def equals(that: Any) = that match { - case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort - case _ => - false - } -} - - -private[spark] object BlockManagerId { - - /** - * Returns a [[spark.storage.BlockManagerId]] for the given configuraiton. - * - * @param execId ID of the executor. - * @param host Host name of the block manager. - * @param port Port of the block manager. - * @param nettyPort Optional port for the Netty-based shuffle sender. - * @return A new [[spark.storage.BlockManagerId]]. - */ - def apply(execId: String, host: String, port: Int, nettyPort: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) - - def apply(in: ObjectInput) = { - val obj = new BlockManagerId() - obj.readExternal(in) - getCachedBlockManagerId(obj) - } - - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() - - def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) - blockManagerIdCache.get(id) - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala deleted file mode 100644 index 76128e8cff..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor.ActorRef -import akka.dispatch.{Await, Future} -import akka.pattern.ask -import akka.util.Duration - -import spark.{Logging, SparkException} -import spark.storage.BlockManagerMessages._ - - -private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { - - val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - - val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ - def removeExecutor(execId: String) { - tell(RemoveExecutor(execId)) - logInfo("Removed " + execId + " successfully in removeExecutor") - } - - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - - /** Register the BlockManager's id with the driver. */ - def registerBlockManager( - blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) - logInfo("Registered BlockManager") - } - - def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) - logInfo("Updated info of block " + blockId) - res - } - - /** Get locations of the blockId from the driver */ - def getLocations(blockId: String): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) - } - - /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) - } - - /** Get ids of other nodes in the cluster from the driver */ - def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) - if (result.length != numPeers) { - throw new SparkException( - "Error getting peers, only got " + result.size + " instead of " + numPeers) - } - result - } - - /** - * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the driver knows about. - */ - def removeBlock(blockId: String) { - askDriverWithReply(RemoveBlock(blockId)) - } - - /** - * Remove all blocks belonging to the given RDD. - */ - def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onFailure { - case e: Throwable => logError("Failed to remove RDD " + rddId, e) - } - if (blocking) { - Await.result(future, timeout) - } - } - - /** - * Return the memory status for each block manager, in the form of a map from - * the block manager's id to two long values. The first value is the maximum - * amount of memory allocated for the block manager, while the second is the - * amount of remaining memory. - */ - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) - } - - def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) - } - - /** Stop the driver actor, called only on the Spark driver node */ - def stop() { - if (driverActor != null) { - tell(StopBlockManagerMaster) - driverActor = null - logInfo("BlockManagerMaster stopped") - } - } - - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ - private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") - } - } - - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) - } - -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala deleted file mode 100644 index b7a981d101..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable -import scala.collection.JavaConversions._ - -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.dispatch.Future -import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ - -import spark.{Logging, Utils, SparkException} -import spark.storage.BlockManagerMessages._ - - -/** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. - */ -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = - new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - - // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] - - val akkaTimeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - initLogging() - - val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong - - val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong - - var timeoutCheckingTask: Cancellable = null - - override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { - timeoutCheckingTask = context.system.scheduler.schedule( - 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } - super.preStart() - } - - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true - - case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => - // TODO: Ideally we want to handle all the message replies in receive instead of in the - // individual private methods. - updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - sender ! getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - sender ! getPeers(blockManagerId, size) - - case GetMemoryStatus => - sender ! memoryStatus - - case GetStorageStatus => - sender ! storageStatus - - case RemoveRdd(rddId) => - sender ! removeRdd(rddId) - - case RemoveBlock(blockId) => - removeBlockFromWorkers(blockId) - sender ! true - - case RemoveExecutor(execId) => - removeExecutor(execId) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) - - case other => - logWarning("Got unknown message: " + other) - } - - private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. - - val prefix = "rdd_" + rddId + "_" - // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) - blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) - } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - val removeMsg = RemoveRdd(rddId) - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq) - } - - private def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - - // Remove the block manager from blockManagerIdByExecutor. - blockManagerIdByExecutor -= blockManagerId.executorId - - // Remove it from blockManagerInfo and remove all the blocks. - blockManagerInfo.remove(blockManagerId) - val iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - if (locations.size == 0) { - blockLocations.remove(locations) - } - } - } - - private def expireDeadHosts() { - logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new mutable.HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + - (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") - toRemove += info.blockManagerId - } - } - toRemove.foreach(removeBlockManager) - } - - private def removeExecutor(execId: String) { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") - blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - } - - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.executorId == "" && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: String) { - val locations = blockLocations.get(blockId) - if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor ! RemoveBlock(blockId) - } - } - } - } - - // Return a map from the block manager id to max memory and remaining memory. - private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - } - - private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case(blockManagerId, info) => - import collection.JavaConverters._ - StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) - }.toArray - } - - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { - blockManagerIdByExecutor.get(id.executorId) match { - case Some(manager) => - // A block manager of the same executor already exists. - // This should never happen. Let's just quit. - logError("Got two different block manager registrations on " + id.executorId) - System.exit(1) - case None => - blockManagerIdByExecutor(id.executorId) = id - } - blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) - } - } - - private def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long) { - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "" && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - sender ! true - } else { - sender ! false - } - return - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true - return - } - - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - - var locations: mutable.HashSet[BlockManagerId] = null - if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId) - } else { - locations = new mutable.HashSet[BlockManagerId] - blockLocations.put(blockId, locations) - } - - if (storageLevel.isValid) { - locations.add(blockManagerId) - } else { - locations.remove(blockManagerId) - } - - // Remove the block from master tracking if it has been removed on all slaves. - if (locations.size == 0) { - blockLocations.remove(blockId) - } - sender ! true - } - - private def getLocations(blockId: String): Seq[BlockManagerId] = { - if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) - } - - private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { - val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - - val selfIndex = peers.indexOf(blockManagerId) - if (selfIndex == -1) { - throw new SparkException("Self index for " + blockManagerId + " not found") - } - - // Note that this logic will select the same node multiple times if there aren't enough peers - Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq - } -} - - -private[spark] -object BlockManagerMasterActor { - - case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] - - logInfo("Registering block manager %s with %s RAM".format( - blockManagerId.hostPort, Utils.bytesToString(maxMem))) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel - - if (originalLevel.useMemory) { - _remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) - if (storageLevel.useMemory) { - _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) - } - } - } - - def removeBlock(blockId: String) { - if (_blocks.containsKey(blockId)) { - _remainingMem += _blocks.get(blockId).memSize - _blocks.remove(blockId) - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[String, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala deleted file mode 100644 index 9375a9ca54..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, ObjectInput, ObjectOutput} - -import akka.actor.ActorRef - - -private[storage] object BlockManagerMessages { - ////////////////////////////////////////////////////////////////////////////////// - // Messages from the master to slaves. - ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerSlave - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - case class RemoveBlock(blockId: String) extends ToBlockManagerSlave - - // Remove all blocks belonging to a specific RDD. - case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave - - - ////////////////////////////////////////////////////////////////////////////////// - // Messages from slaves to the master. - ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerMaster - - case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long, - sender: ActorRef) - extends ToBlockManagerMaster - - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - - class UpdateBlockInfo( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var memSize: Long, - var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { - - def this() = this(null, null, null, 0, 0) // For deserialization only - - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeLong(memSize) - out.writeLong(diskSize) - } - - override def readExternal(in: ObjectInput) { - blockManagerId = BlockManagerId(in) - blockId = in.readUTF() - storageLevel = StorageLevel(in) - memSize = in.readLong() - diskSize = in.readLong() - } - } - - object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) - } - } - - case class GetLocations(blockId: String) extends ToBlockManagerMaster - - case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster - - case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster - - case class RemoveExecutor(execId: String) extends ToBlockManagerMaster - - case object StopBlockManagerMaster extends ToBlockManagerMaster - - case object GetMemoryStatus extends ToBlockManagerMaster - - case object ExpireDeadHosts extends ToBlockManagerMaster - - case object GetStorageStatus extends ToBlockManagerMaster -} diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala deleted file mode 100644 index 6e5fb43732..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor.Actor - -import spark.storage.BlockManagerMessages._ - - -/** - * An actor to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. - */ -class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { - override def receive = { - - case RemoveBlock(blockId) => - blockManager.removeBlock(blockId) - - case RemoveRdd(rddId) => - val numBlocksRemoved = blockManager.removeRdd(rddId) - sender ! numBlocksRemoved - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerSource.scala b/core/src/main/scala/spark/storage/BlockManagerSource.scala deleted file mode 100644 index 2aecd1ea71..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerSource.scala +++ /dev/null @@ -1,48 +0,0 @@ -package spark.storage - -import com.codahale.metrics.{Gauge,MetricRegistry} - -import spark.metrics.source.Source - - -private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "BlockManager" - - metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) - (maxMem - remainingMem) / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList - .flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_ + _) - .getOrElse(0L) - - diskSpaceUsed / 1024 / 1024 - } - }) -} diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index 39064bce92..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import spark.{Logging, Utils} -import spark.network._ - -/** - * A network interface for BlockManager. Each slave should have one - * BlockManagerWorker. - * - * TODO: Use event model. - */ -private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - initLogging() - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - return Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => logError("Exception handling buffer message", e) - return None - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - return None - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - return None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - logDebug("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => return None - } - } - - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(id: String): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + id + " started from " + startTimeMs) - val buffer = blockManager.getLocalBytes(id) match { - case Some(bytes) => bytes - case None => null - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - return buffer - } -} - -private[spark] object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - - initLogging() - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - return (resultMessage != None) - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - responseMessage match { - case Some(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) - } - case None => logDebug("No response message received"); return null - } - return null - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala deleted file mode 100644 index bcce26b7c1..0000000000 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import scala.collection.mutable.StringBuilder -import scala.collection.mutable.ArrayBuffer - -import spark.network._ - -private[spark] case class GetBlock(id: String) -private[spark] case class GotBlock(id: String, data: ByteBuffer) -private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) - -private[spark] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - val startTime = System.currentTimeMillis - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = idBuilder.toString() - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - val finishTime = System.currentTimeMillis - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = { - return typ - } - - def getId: String = { - return id - } - - def getData: ByteBuffer = { - return data - } - - def getLevel: StorageLevel = { - return level - } - - def toBufferMessage: BufferMessage = { - val startTime = System.currentTimeMillis - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - val finishTime = System.currentTimeMillis - return Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[spark] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } - - def main(args: Array[String]) { - val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) - val bMsg = B.toBufferMessage - val C = new BlockMessage() - C.set(bMsg) - - println(B.getId + " " + B.getLevel) - println(C.getId + " " + C.getLevel) - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala deleted file mode 100644 index ee2fc167d5..0000000000 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import spark._ -import spark.network._ - -private[spark] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int) = blockMessages(i) - - def iterator = blockMessages.iterator - - def length = blockMessages.length - - initLogging() - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - return Message.createBufferMessage(buffers) - } -} - -private[spark] object BlockMessageArray { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - } -} - - diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala deleted file mode 100644 index 3812009ca1..0000000000 --- a/core/src/main/scala/spark/storage/BlockObjectWriter.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - - -/** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. - * - * This interface does not support concurrent writes. - */ -abstract class BlockObjectWriter(val blockId: String) { - - var closeEventHandler: () => Unit = _ - - def open(): BlockObjectWriter - - def close() { - closeEventHandler() - } - - def isOpen: Boolean - - def registerCloseEventHandler(handler: () => Unit) { - closeEventHandler = handler - } - - /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. - */ - def commit(): Long - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. - */ - def revertPartialWrites() - - /** - * Writes an object. - */ - def write(value: Any) - - /** - * Size of the valid writes, in bytes. - */ - def size(): Long -} diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala deleted file mode 100644 index c8db0022b0..0000000000 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - -import spark.Logging - -/** - * Abstract class to store blocks - */ -private[spark] -abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) - - /** - * Put in a block and, possibly, also return its content as either bytes or another Iterator. - * This is used to efficiently write the values to multiple locations (e.g. for replication). - * - * @return a PutResult that contains the size of the data, as well as the values put if - * returnValues is true (if not, the result's data field can be null) - */ - def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - returnValues: Boolean) : PutResult - - /** - * Return the size of a block in bytes. - */ - def getSize(blockId: String): Long - - def getBytes(blockId: String): Option[ByteBuffer] - - def getValues(blockId: String): Option[Iterator[Any]] - - /** - * Remove a block, if it exists. - * @param blockId the block to remove. - * @return True if the block was found and removed, False otherwise. - */ - def remove(blockId: String): Boolean - - def contains(blockId: String): Boolean - - def clear() { } -} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala deleted file mode 100644 index b14497157e..0000000000 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode -import java.util.{Random, Date} -import java.text.SimpleDateFormat - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import spark.Utils -import spark.executor.ExecutorExitCode -import spark.serializer.{Serializer, SerializationStream} -import spark.Logging -import spark.network.netty.ShuffleSender -import spark.network.netty.PathResolver - - -/** - * Stores BlockManager blocks on disk. - */ -private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) with Logging { - - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { - - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. - private var channel: FileChannel = null - private var bs: OutputStream = null - private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false - - override def open(): DiskBlockObjectWriter = { - val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - objOut.close() - channel = null - bs = null - objOut = null - } - // Invoke the close callback handler. - super.close() - } - - override def isOpen: Boolean = objOut != null - - // Flush the partial writes, and set valid length to be the length of the entire file. - // Return the number of bytes written for this commit. - override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } - } - - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - objOut.writeObject(value) - } - - override def size(): Long = lastValidPosition - } - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - - private var shuffleSender : ShuffleSender = null - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. - private val localDirs: Array[File] = createLocalDirs() - private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - - addShutdownHook() - - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) - } - - override def getSize(blockId: String): Long = { - getFile(blockId).length() - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val bytes = _bytes.duplicate() - logDebug("Attempting to put block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - val channel = new RandomAccessFile(file, "rw").getChannel() - while (bytes.remaining > 0) { - channel.write(bytes) - } - channel.close() - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - } - - private def getFileBytes(file: File): ByteBuffer = { - val length = file.length() - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = try { - channel.map(MapMode.READ_ONLY, 0, length) - } finally { - channel.close() - } - - buffer - } - - override def putValues( - blockId: String, - values: ArrayBuffer[Any], - level: StorageLevel, - returnValues: Boolean) - : PutResult = { - - logDebug("Attempting to write values for block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - val fileOut = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) - objOut.writeAll(values.iterator) - objOut.close() - val length = file.length() - - val timeTaken = System.currentTimeMillis - startTime - logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(length), timeTaken)) - - if (returnValues) { - // Return a byte buffer for the contents of the file - val buffer = getFileBytes(file) - PutResult(length, Right(buffer)) - } else { - PutResult(length, null) - } - } - - override def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val bytes = getFileBytes(file) - Some(bytes) - } - - override def getValues(blockId: String): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) - } - - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - - override def remove(blockId: String): Boolean = { - val file = getFile(blockId) - if (file.exists()) { - file.delete() - } else { - false - } - } - - override def contains(blockId: String): Boolean = { - getFile(blockId).exists() - } - - private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) - if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") - file.delete() - } - file - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = math.abs(blockId.hashCode) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - - new File(subDir, blockId) - } - - private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created local directory at " + localDir) - localDir - } - } - - private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach { localDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case t: Throwable => - logError("Exception while deleting local spark dir: " + localDir, t) - } - } - if (shuffleSender != null) { - shuffleSender.stop - } - } - }) - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - DiskStore.this.getFile(blockId).getAbsolutePath() - } - } - shuffleSender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) - shuffleSender.port - } -} diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala deleted file mode 100644 index 5a51f5cf31..0000000000 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.util.LinkedHashMap -import java.util.concurrent.ArrayBlockingQueue -import spark.{SizeEstimator, Utils} -import java.nio.ByteBuffer -import collection.mutable.ArrayBuffer - -/** - * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as - * serialized ByteBuffers. - */ -private class MemoryStore(blockManager: BlockManager, maxMemory: Long) - extends BlockStore(blockManager) { - - case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) - - private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) - private var currentMemory = 0L - // Object used to ensure that only one thread is putting blocks and if necessary, dropping - // blocks from the memory store. - private val putLock = new Object() - - logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory))) - - def freeMemory: Long = maxMemory - currentMemory - - override def getSize(blockId: String): Long = { - entries.synchronized { - entries.get(blockId).size - } - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { - // Work on a duplicate - since the original input might be used elsewhere. - val bytes = _bytes.duplicate() - bytes.rewind() - if (level.deserialized) { - val values = blockManager.dataDeserialize(blockId, bytes) - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - tryToPut(blockId, elements, sizeEstimate, true) - } else { - tryToPut(blockId, bytes, bytes.limit, false) - } - } - - override def putValues( - blockId: String, - values: ArrayBuffer[Any], - level: StorageLevel, - returnValues: Boolean) - : PutResult = { - - if (level.deserialized) { - val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, true) - PutResult(sizeEstimate, Left(values.iterator)) - } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, false) - PutResult(bytes.limit(), Right(bytes.duplicate())) - } - } - - override def getBytes(blockId: String): Option[ByteBuffer] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) - } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data - } - } - - override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) - } else { - val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data - Some(blockManager.dataDeserialize(blockId, buffer)) - } - } - - override def remove(blockId: String): Boolean = { - entries.synchronized { - val entry = entries.get(blockId) - if (entry != null) { - entries.remove(blockId) - currentMemory -= entry.size - logInfo("Block %s of size %d dropped from memory (free %d)".format( - blockId, entry.size, freeMemory)) - true - } else { - false - } - } - } - - override def clear() { - entries.synchronized { - entries.clear() - } - logInfo("MemoryStore cleared") - } - - /** - * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. - */ - private def getRddId(blockId: String): String = { - if (blockId.startsWith("rdd_")) { - blockId.split('_')(1) - } else { - null - } - } - - /** - * Try to put in a set of values, if we can free up enough space. The value should either be - * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) - * size must also be passed by the caller. - * - * Locks on the object putLock to ensure that all the put requests and its associated block - * dropping is done by only on thread at a time. Otherwise while one thread is dropping - * blocks to free memory for one block, another thread may use up the freed space for - * another block. - */ - private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { - // TODO: Its possible to optimize the locking by locking entries only when selecting blocks - // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been - // released, it must be ensured that those to-be-dropped blocks are not double counted for - // freeing up more space for another block that needs to be put. Only then the actually dropping - // of blocks (and writing to disk if necessary) can proceed in parallel. - putLock.synchronized { - if (ensureFreeSpace(blockId, size)) { - val entry = new Entry(value, size, deserialized) - entries.synchronized { entries.put(blockId, entry) } - currentMemory += size - if (deserialized) { - logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } else { - logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } - true - } else { - // Tell the block manager that we couldn't put it in memory so that it can drop it to - // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - false - } - } - } - - /** - * Tries to free up a given amount of space to store a particular block, but can fail and return - * false if either the block is bigger than our memory or it would require replacing another - * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. - * Otherwise, the freed space may fill up before the caller puts in their new value. - */ - private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { - - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) - - if (space > maxMemory) { - logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") - return false - } - - if (maxMemory - currentMemory < space) { - val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[String]() - var selectedMemory = 0L - - // This is synchronized to ensure that the set of entries is not changed - // (because of getValue or getBytes) while traversing the iterator, as that - // can lead to exceptions. - entries.synchronized { - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false - } - selectedBlocks += blockId - selectedMemory += pair.getValue.size - } - } - - if (maxMemory - (currentMemory - selectedMemory) >= space) { - logInfo(selectedBlocks.size + " blocks selected for dropping") - for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } - } - return true - } else { - return false - } - } - return true - } - - override def contains(blockId: String): Boolean = { - entries.synchronized { entries.containsKey(blockId) } - } -} - diff --git a/core/src/main/scala/spark/storage/PutResult.scala b/core/src/main/scala/spark/storage/PutResult.scala deleted file mode 100644 index 3a0974fe15..0000000000 --- a/core/src/main/scala/spark/storage/PutResult.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -/** - * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the - * values put if the caller asked for them to be returned (e.g. for chaining replication) - */ -private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer]) diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala deleted file mode 100644 index 8a7a6f9ed3..0000000000 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import spark.serializer.Serializer - - -private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) - - -private[spark] -trait ShuffleBlocks { - def acquireWriters(mapId: Int): ShuffleWriterGroup - def releaseWriters(group: ShuffleWriterGroup) -} - - -private[spark] -class ShuffleBlockManager(blockManager: BlockManager) { - - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { - new ShuffleBlocks { - // Get a group of writers for a map task. - override def acquireWriters(mapId: Int): ShuffleWriterGroup = { - val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 - val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) - } - new ShuffleWriterGroup(mapId, writers) - } - - override def releaseWriters(group: ShuffleWriterGroup) = { - // Nothing really to release here. - } - } - } -} - - -private[spark] -object ShuffleBlockManager { - - // Returns the block id for a given shuffle block. - def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { - "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId - } - - // Returns true if the block is a shuffle block. - def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") -} diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala deleted file mode 100644 index f52650988c..0000000000 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} - -/** - * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, - * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. - * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. To create your own storage level object, use the factor method - * of the singleton object (`StorageLevel(...)`). - */ -class StorageLevel private( - private var useDisk_ : Boolean, - private var useMemory_ : Boolean, - private var deserialized_ : Boolean, - private var replication_ : Int = 1) - extends Externalizable { - - // TODO: Also add fields for caching priority, dataset ID, and flushing. - private def this(flags: Int, replication: Int) { - this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) - } - - def this() = this(false, true, false) // For deserialization - - def useDisk = useDisk_ - def useMemory = useMemory_ - def deserialized = deserialized_ - def replication = replication_ - - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.deserialized, this.replication) - - override def equals(other: Any): Boolean = other match { - case s: StorageLevel => - s.useDisk == useDisk && - s.useMemory == useMemory && - s.deserialized == deserialized && - s.replication == replication - case _ => - false - } - - def isValid = ((useMemory || useDisk) && (replication > 0)) - - def toInt: Int = { - var ret = 0 - if (useDisk_) { - ret |= 4 - } - if (useMemory_) { - ret |= 2 - } - if (deserialized_) { - ret |= 1 - } - return ret - } - - override def writeExternal(out: ObjectOutput) { - out.writeByte(toInt) - out.writeByte(replication_) - } - - override def readExternal(in: ObjectInput) { - val flags = in.readByte() - useDisk_ = (flags & 4) != 0 - useMemory_ = (flags & 2) != 0 - deserialized_ = (flags & 1) != 0 - replication_ = in.readByte() - } - - @throws(classOf[IOException]) - private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) - - override def toString: String = - "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) - - override def hashCode(): Int = toInt * 41 + replication - def description : String = { - var result = "" - result += (if (useDisk) "Disk " else "") - result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") - result += "%sx Replicated".format(replication) - result - } -} - - -object StorageLevel { - val NONE = new StorageLevel(false, false, false) - val DISK_ONLY = new StorageLevel(true, false, false) - val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) - val MEMORY_ONLY = new StorageLevel(false, true, true) - val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2) - val MEMORY_ONLY_SER = new StorageLevel(false, true, false) - val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2) - val MEMORY_AND_DISK = new StorageLevel(true, true, true) - val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) - val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) - val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) - - /** Create a new StorageLevel object */ - def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = - getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) - - /** Create a new StorageLevel object from its integer representation */ - def apply(flags: Int, replication: Int) = - getCachedStorageLevel(new StorageLevel(flags, replication)) - - /** Read StorageLevel object from ObjectInput stream */ - def apply(in: ObjectInput) = { - val obj = new StorageLevel() - obj.readExternal(in) - getCachedStorageLevel(obj) - } - - private[spark] - val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() - - private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - storageLevelCache.putIfAbsent(level, level) - storageLevelCache.get(level) - } -} diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala deleted file mode 100644 index 123b8f6345..0000000000 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import spark.{Utils, SparkContext} -import BlockManagerMasterActor.BlockStatus - -private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, - blocks: Map[String, BlockStatus]) { - - def memUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). - reduceOption(_+_).getOrElse(0l) - } - - def diskUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). - reduceOption(_+_).getOrElse(0l) - } - - def memRemaining : Long = maxMem - memUsed() - -} - -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) - extends Ordered[RDDInfo] { - override def toString = { - import Utils.bytesToString - "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, - storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(diskSize)) - } - - override def compare(that: RDDInfo) = { - this.id - that.id - } -} - -/* Helper methods for storage-related objects */ -private[spark] -object StorageUtils { - - /* Returns RDD-level information, compiled from a list of StorageStatus objects */ - def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], - sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) - } - - /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ - def blockLocationsFromStorageStatus(storageStatusList: Seq[StorageStatus]) = { - val blockLocationPairs = storageStatusList - .flatMap(s => s.blocks.map(b => (b._1, s.blockManagerId.hostPort))) - blockLocationPairs.groupBy(_._1).map{case (k, v) => (k, v.unzip._2)}.toMap - } - - /* Given a list of BlockStatus objets, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], - sc: SparkContext) : Array[RDDInfo] = { - - // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => - k.substring(0,k.lastIndexOf('_')) - }.mapValues(_.values.toArray) - - // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => - // Add up memory and disk sizes - val memSize = rddBlocks.map(_.memSize).reduce(_ + _) - val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) - - // Find the id of the RDD, e.g. rdd_1 => 1 - val rddId = rddKey.split("_").last.toInt - - // Get the friendly name and storage level for the RDD, if available - sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) - val rddStorageLevel = r.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) - } - }.flatten.toArray - - scala.util.Sorting.quickSort(rddInfos) - - rddInfos - } - - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], - prefix: String) : Array[StorageStatus] = { - - storageStatusList.map { status => - val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) - //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) - StorageStatus(status.blockManagerId, status.maxMem, newBlocks) - } - - } - -} diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala deleted file mode 100644 index b3ab1ff4b4..0000000000 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor._ - -import spark.KryoSerializer -import java.util.concurrent.ArrayBlockingQueue -import util.Random - -/** - * This class tests the BlockManager and MemoryStore for thread safety and - * deadlocks. It spawns a number of producer and consumer threads. Producer - * threads continuously pushes blocks into the BlockManager and consumer - * threads continuously retrieves the blocks form the BlockManager and tests - * whether the block is correct or not. - */ -private[spark] object ThreadingTest { - - val numProducers = 5 - val numBlocksPerProducer = 20000 - - private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) - - override def run() { - for (i <- 1 to numBlocksPerProducer) { - val blockId = "b-" + id + "-" + i - val blockSize = Random.nextInt(1000) - val block = (1 to blockSize).map(_ => Random.nextInt()) - val level = randomLevel() - val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) - println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") - queue.add((blockId, block)) - } - println("Producer thread " + id + " terminated") - } - - def randomLevel(): StorageLevel = { - math.abs(Random.nextInt()) % 4 match { - case 0 => StorageLevel.MEMORY_ONLY - case 1 => StorageLevel.MEMORY_ONLY_SER - case 2 => StorageLevel.MEMORY_AND_DISK - case 3 => StorageLevel.MEMORY_AND_DISK_SER - } - } - } - - private[spark] class ConsumerThread( - manager: BlockManager, - queue: ArrayBlockingQueue[(String, Seq[Int])] - ) extends Thread { - var numBlockConsumed = 0 - - override def run() { - println("Consumer thread started") - while(numBlockConsumed < numBlocksPerProducer) { - val (blockId, block) = queue.take() - val startTime = System.currentTimeMillis() - manager.get(blockId) match { - case Some(retrievedBlock) => - assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, - "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + - (System.currentTimeMillis - startTime) + " ms") - case None => - assert(false, "Block " + blockId + " could not be retrieved") - } - numBlockConsumed += 1 - } - println("Consumer thread terminated") - } - } - - def main(args: Array[String]) { - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val actorSystem = ActorSystem("test") - val serializer = new KryoSerializer - val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) - val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) - val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) - val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) - producers.foreach(_.start) - consumers.foreach(_.start) - producers.foreach(_.join) - consumers.foreach(_.join) - blockManager.stop() - blockManagerMaster.stop() - actorSystem.shutdown() - actorSystem.awaitTermination() - println("Everything stopped.") - println( - "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") - } -} diff --git a/core/src/main/scala/spark/ui/JettyUtils.scala b/core/src/main/scala/spark/ui/JettyUtils.scala deleted file mode 100644 index f66fe39905..0000000000 --- a/core/src/main/scala/spark/ui/JettyUtils.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} - -import scala.annotation.tailrec -import scala.util.{Try, Success, Failure} -import scala.xml.Node - -import net.liftweb.json.{JValue, pretty, render} - -import org.eclipse.jetty.server.{Server, Request, Handler} -import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool - -import spark.Logging - - -/** Utilities for launching a web server using Jetty's HTTP Server class */ -private[spark] object JettyUtils extends Logging { - // Base type for a function that returns something based on an HTTP request. Allows for - // implicit conversion from many types of functions to jetty Handlers. - type Responder[T] = HttpServletRequest => T - - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) - - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = - createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - - implicit def textResponderToHandler(responder: Responder[String]): Handler = - createHandler(responder, "text/plain") - - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - response.setStatus(HttpServletResponse.SC_OK) - baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) - } - } - } - - /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) - } - } - } - - /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler - Option(getClass.getClassLoader.getResource(resourceBase)) match { - case Some(res) => - staticHandler.setResourceBase(res.toString) - case None => - throw new Exception("Could not find resource path for Web UI: " + resourceBase) - } - staticHandler - } - - /** - * Attempts to start a Jetty server at the supplied ip:port which uses the supplied handlers. - * - * If the desired port number is contented, continues incrementing ports until a free port is - * found. Returns the chosen port and the jetty Server object. - */ - def startJettyServer(ip: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) = { - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } - - val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) - - @tailrec - def connect(currentPort: Int): (Server, Int) = { - val server = new Server(currentPort) - val pool = new QueuedThreadPool - pool.setDaemon(true) - server.setThreadPool(pool) - server.setHandler(handlerList) - - Try { server.start() } match { - case s: Success[_] => - sys.addShutdownHook(server.stop()) // Be kind, un-bind - (server, server.getConnectors.head.getLocalPort) - case f: Failure[_] => - server.stop() - logInfo("Failed to create UI at port, %s. Trying again.".format(currentPort)) - logInfo("Error was: " + f.toString) - connect((currentPort + 1) % 65536) - } - } - - connect(port) - } -} diff --git a/core/src/main/scala/spark/ui/Page.scala b/core/src/main/scala/spark/ui/Page.scala deleted file mode 100644 index 87376a19d8..0000000000 --- a/core/src/main/scala/spark/ui/Page.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -private[spark] object Page extends Enumeration { - val Stages, Storage, Environment, Executors = Value -} diff --git a/core/src/main/scala/spark/ui/SparkUI.scala b/core/src/main/scala/spark/ui/SparkUI.scala deleted file mode 100644 index 23ded44ba3..0000000000 --- a/core/src/main/scala/spark/ui/SparkUI.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.{Logging, SparkContext, SparkEnv, Utils} -import spark.ui.env.EnvironmentUI -import spark.ui.exec.ExecutorsUI -import spark.ui.storage.BlockManagerUI -import spark.ui.jobs.JobProgressUI -import spark.ui.JettyUtils._ - -/** Top level user interface for Spark */ -private[spark] class SparkUI(sc: SparkContext) extends Logging { - val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName()) - val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt - var boundPort: Option[Int] = None - var server: Option[Server] = None - - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) - ) - val storage = new BlockManagerUI(sc) - val jobs = new JobProgressUI(sc) - val env = new EnvironmentUI(sc) - val exec = new ExecutorsUI(sc) - - // Add MetricsServlet handlers by default - val metricsServletHandlers = SparkEnv.get.metricsSystem.getServletHandlers - - val allHandlers = storage.getHandlers ++ jobs.getHandlers ++ env.getHandlers ++ - exec.getHandlers ++ metricsServletHandlers ++ handlers - - /** Bind the HTTP server which backs this web interface */ - def bind() { - try { - val (srv, usedPort) = JettyUtils.startJettyServer("0.0.0.0", port, allHandlers) - logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) - server = Some(srv) - boundPort = Some(usedPort) - } catch { - case e: Exception => - logError("Failed to create Spark JettyUtils", e) - System.exit(1) - } - } - - /** Initialize all components of the server */ - def start() { - // NOTE: This is decoupled from bind() because of the following dependency cycle: - // DAGScheduler() requires that the port of this server is known - // This server must register all handlers, including JobProgressUI, before binding - // JobProgressUI registers a listener with SparkContext, which requires sc to initialize - jobs.start() - exec.start() - } - - def stop() { - server.foreach(_.stop()) - } - - private[spark] def appUIAddress = "http://" + host + ":" + boundPort.getOrElse("-1") -} - -private[spark] object SparkUI { - val DEFAULT_PORT = "3030" - val STATIC_RESOURCE_DIR = "spark/ui/static" -} diff --git a/core/src/main/scala/spark/ui/UIUtils.scala b/core/src/main/scala/spark/ui/UIUtils.scala deleted file mode 100644 index 51bb18d888..0000000000 --- a/core/src/main/scala/spark/ui/UIUtils.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import scala.xml.Node - -import spark.SparkContext - -/** Utility functions for generating XML pages with spark content. */ -private[spark] object UIUtils { - import Page._ - - /** Returns a spark page with correctly formatted headers */ - def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value) - : Seq[Node] = { - val jobs = page match { - case Stages =>
  • Stages
  • - case _ =>
  • Stages
  • - } - val storage = page match { - case Storage =>
  • Storage
  • - case _ =>
  • Storage
  • - } - val environment = page match { - case Environment =>
  • Environment
  • - case _ =>
  • Environment
  • - } - val executors = page match { - case Executors =>
  • Executors
  • - case _ =>
  • Executors
  • - } - - - - - - - - {sc.appName} - {title} - - - - -
    -
    -
    -

    - {title} -

    -
    -
    - {content} -
    - - - } - - /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */ - def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = { - - - - - - - {title} - - -
    -
    -
    -

    - - {title} -

    -
    -
    - {content} -
    - - - } - - /** Returns an HTML table constructed by generating a row for each object in a sequence. */ - def listingTable[T]( - headers: Seq[String], - makeRow: T => Seq[Node], - rows: Seq[T], - fixedWidth: Boolean = false): Seq[Node] = { - - val colWidth = 100.toDouble / headers.size - val colWidthAttr = if (fixedWidth) colWidth + "%" else "" - var tableClass = "table table-bordered table-striped table-condensed sortable" - if (fixedWidth) { - tableClass += " table-fixed" - } - - - {headers.map(h => )} - - {rows.map(r => makeRow(r))} - -
    {h}
    - } -} diff --git a/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala deleted file mode 100644 index 5ff0572f0a..0000000000 --- a/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import scala.util.Random - -import spark.SparkContext -import spark.SparkContext._ -import spark.scheduler.cluster.SchedulingMode - - -/** - * Continuously generates jobs that expose various features of the WebUI (internal testing tool). - * - * Usage: ./run spark.ui.UIWorkloadGenerator [master] - */ -private[spark] object UIWorkloadGenerator { - val NUM_PARTITIONS = 100 - val INTER_JOB_WAIT_MS = 5000 - - def main(args: Array[String]) { - if (args.length < 2) { - println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") - System.exit(1) - } - val master = args(0) - val schedulingMode = SchedulingMode.withName(args(1)) - val appName = "Spark UI Tester" - - if (schedulingMode == SchedulingMode.FAIR) { - System.setProperty("spark.cluster.schedulingmode", "FAIR") - } - val sc = new SparkContext(master, appName) - - def setProperties(s: String) = { - if(schedulingMode == SchedulingMode.FAIR) { - sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s) - } - sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) - } - - val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = (new Random()).nextFloat() - - val jobs = Seq[(String, () => Long)]( - ("Count", baseData.count), - ("Cache and Count", baseData.map(x => x).cache.count), - ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), - ("Entirely failed phase", baseData.map(x => throw new Exception).count), - ("Partially failed phase", { - baseData.map{x => - val probFailure = (4.0 / NUM_PARTITIONS) - if (nextFloat() < probFailure) { - throw new Exception("This is a task failure") - } - 1 - }.count - }), - ("Partially failed phase (longer tasks)", { - baseData.map{x => - val probFailure = (4.0 / NUM_PARTITIONS) - if (nextFloat() < probFailure) { - Thread.sleep(100) - throw new Exception("This is a task failure") - } - 1 - }.count - }), - ("Job with delays", baseData.map(x => Thread.sleep(100)).count) - ) - - while (true) { - for ((desc, job) <- jobs) { - new Thread { - override def run() { - try { - setProperties(desc) - job() - println("Job funished: " + desc) - } catch { - case e: Exception => - println("Job Failed: " + desc) - } - } - }.start - Thread.sleep(INTER_JOB_WAIT_MS) - } - } - } -} diff --git a/core/src/main/scala/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/spark/ui/env/EnvironmentUI.scala deleted file mode 100644 index b1be1a27ef..0000000000 --- a/core/src/main/scala/spark/ui/env/EnvironmentUI.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.env - -import javax.servlet.http.HttpServletRequest - -import scala.collection.JavaConversions._ -import scala.util.Properties -import scala.xml.Node - -import org.eclipse.jetty.server.Handler - -import spark.ui.JettyUtils._ -import spark.ui.UIUtils -import spark.ui.Page.Environment -import spark.SparkContext - - -private[spark] class EnvironmentUI(sc: SparkContext) { - - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) - ) - - def envDetails(request: HttpServletRequest): Seq[Node] = { - val jvmInformation = Seq( - ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), - ("Java Home", Properties.javaHome), - ("Scala Version", Properties.versionString), - ("Scala Home", Properties.scalaHome) - ).sorted - def jvmRow(kv: (String, String)) = {kv._1}{kv._2} - def jvmTable = - UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) - - val properties = System.getProperties.iterator.toSeq - val classPathProperty = properties.find { case (k, v) => - k.contains("java.class.path") - }.getOrElse(("", "")) - val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted - val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted - - val propertyHeaders = Seq("Name", "Value") - def propertyRow(kv: (String, String)) = {kv._1}{kv._2} - val sparkPropertyTable = - UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true) - val otherPropertyTable = - UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) - - val classPathEntries = classPathProperty._2 - .split(System.getProperty("path.separator", ":")) - .filterNot(e => e.isEmpty) - .map(e => (e, "System Classpath")) - val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} - val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")} - val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted - - val classPathHeaders = Seq("Resource", "Source") - def classPathRow(data: (String, String)) = {data._1}{data._2} - val classPathTable = - UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true) - - val content = - -

    Runtime Information

    {jvmTable} -

    Spark Properties

    - {sparkPropertyTable} -

    System Properties

    - {otherPropertyTable} -

    Classpath Entries

    - {classPathTable} -
    - - UIUtils.headerSparkPage(content, sc, "Environment", Environment) - } -} diff --git a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala deleted file mode 100644 index 0a7021fbf8..0000000000 --- a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala +++ /dev/null @@ -1,136 +0,0 @@ -package spark.ui.exec - -import javax.servlet.http.HttpServletRequest - -import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.Node - -import org.eclipse.jetty.server.Handler - -import spark.{ExceptionFailure, Logging, Utils, SparkContext} -import spark.executor.TaskMetrics -import spark.scheduler.cluster.TaskInfo -import spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} -import spark.ui.JettyUtils._ -import spark.ui.Page.Executors -import spark.ui.UIUtils - - -private[spark] class ExecutorsUI(val sc: SparkContext) { - - private var _listener: Option[ExecutorsListener] = None - def listener = _listener.get - - def start() { - _listener = Some(new ExecutorsListener) - sc.addSparkListener(listener) - } - - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) - ) - - def render(request: HttpServletRequest): Seq[Node] = { - val storageStatusList = sc.getExecutorStorageStatus - - val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_) - val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) - - val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", - "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") - - def execRow(kv: Seq[String]) = { - - {kv(0)} - {kv(1)} - {kv(2)} - - {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)} - - - {Utils.bytesToString(kv(5).toLong)} - - {kv(6)} - {kv(7)} - {kv(8)} - {kv(9)} - - } - - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) - val execTable = UIUtils.listingTable(execHead, execRow, execInfo) - - val content = -
    -
    -
      -
    • Memory: - {Utils.bytesToString(memUsed)} Used - ({Utils.bytesToString(maxMem)} Total)
    • -
    • Disk: {Utils.bytesToString(diskSpaceUsed)} Used
    • -
    -
    -
    -
    -
    - {execTable} -
    -
    ; - - UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) - } - - def getExecInfo(a: Int): Seq[String] = { - val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId - val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort - val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString - val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString - val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString - val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString - val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) - val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) - val totalTasks = activeTasks + failedTasks + completedTasks - - Seq( - execId, - hostPort, - rddBlocks, - memUsed, - maxMem, - diskUsed, - activeTasks.toString, - failedTasks.toString, - completedTasks.toString, - totalTasks.toString - ) - } - - private[spark] class ExecutorsListener extends SparkListener with Logging { - val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() - val executorToTasksComplete = HashMap[String, Int]() - val executorToTasksFailed = HashMap[String, Int]() - - override def onTaskStart(taskStart: SparkListenerTaskStart) { - val eid = taskStart.taskInfo.executorId - val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) - activeTasks += taskStart.taskInfo - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val eid = taskEnd.taskInfo.executorId - val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) - activeTasks -= taskEnd.taskInfo - val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = - taskEnd.reason match { - case e: ExceptionFailure => - executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 - (Some(e), e.metrics) - case _ => - executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 - (None, Option(taskEnd.taskMetrics)) - } - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/spark/ui/jobs/IndexPage.scala deleted file mode 100644 index 8867a6c90c..0000000000 --- a/core/src/main/scala/spark/ui/jobs/IndexPage.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import scala.xml.{NodeSeq, Node} - -import spark.scheduler.cluster.SchedulingMode -import spark.ui.Page._ -import spark.ui.UIUtils._ - - -/** Page showing list of all ongoing and recently finished stages and pools*/ -private[spark] class IndexPage(parent: JobProgressUI) { - def listener = parent.listener - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val activeStages = listener.activeStages.toSeq - val completedStages = listener.completedStages.reverse.toSeq - val failedStages = listener.failedStages.reverse.toSeq - val now = System.currentTimeMillis() - - var activeTime = 0L - for (tasks <- listener.stageToTasksActive.values; t <- tasks) { - activeTime += t.timeRunning(now) - } - - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) - val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) - val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) - - val pools = listener.sc.getAllPools - val poolTable = new PoolTable(pools, listener) - val summary: NodeSeq = -
    -
      -
    • - Total Duration: - {parent.formatDuration(now - listener.sc.startTime)} -
    • -
    • Scheduling Mode: {parent.sc.getSchedulingMode}
    • -
    • - Active Stages: - {activeStages.size} -
    • -
    • - Completed Stages: - {completedStages.size} -
    • -
    • - Failed Stages: - {failedStages.size} -
    • -
    -
    - - val content = summary ++ - {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) { -

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq - } else { - Seq() - }} ++ -

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq++ -

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq++ -

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq - - headerSparkPage(content, parent.sc, "Spark Stages", Stages) - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala deleted file mode 100644 index 1d9767a83c..0000000000 --- a/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala +++ /dev/null @@ -1,156 +0,0 @@ -package spark.ui.jobs - -import scala.Seq -import scala.collection.mutable.{ListBuffer, HashMap, HashSet} - -import spark.{ExceptionFailure, SparkContext, Success, Utils} -import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo -import spark.executor.TaskMetrics -import collection.mutable - -/** - * Tracks task-level information to be displayed in the UI. - * - * All access to the data structures in this class must be synchronized on the - * class, since the UI thread and the DAGScheduler event loop may otherwise - * be reading/updating the internal data structures concurrently. - */ -private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { - // How many stages to remember - val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt - val DEFAULT_POOL_NAME = "default" - - val stageToPool = new HashMap[Stage, String]() - val stageToDescription = new HashMap[Stage, String]() - val poolToActiveStages = new HashMap[String, HashSet[Stage]]() - - val activeStages = HashSet[Stage]() - val completedStages = ListBuffer[Stage]() - val failedStages = ListBuffer[Stage]() - - // Total metrics reflect metrics only for completed tasks - var totalTime = 0L - var totalShuffleRead = 0L - var totalShuffleWrite = 0L - - val stageToTime = HashMap[Int, Long]() - val stageToShuffleRead = HashMap[Int, Long]() - val stageToShuffleWrite = HashMap[Int, Long]() - val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() - val stageToTasksComplete = HashMap[Int, Int]() - val stageToTasksFailed = HashMap[Int, Int]() - val stageToTaskInfos = - HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() - - override def onJobStart(jobStart: SparkListenerJobStart) {} - - override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { - val stage = stageCompleted.stageInfo.stage - poolToActiveStages(stageToPool(stage)) -= stage - activeStages -= stage - completedStages += stage - trimIfNecessary(completedStages) - } - - /** If stages is too large, remove and garbage collect old stages */ - def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { - if (stages.size > RETAINED_STAGES) { - val toRemove = RETAINED_STAGES / 10 - stages.takeRight(toRemove).foreach( s => { - stageToTaskInfos.remove(s.id) - stageToTime.remove(s.id) - stageToShuffleRead.remove(s.id) - stageToShuffleWrite.remove(s.id) - stageToTasksActive.remove(s.id) - stageToTasksComplete.remove(s.id) - stageToTasksFailed.remove(s.id) - stageToPool.remove(s) - if (stageToDescription.contains(s)) {stageToDescription.remove(s)} - }) - stages.trimEnd(toRemove) - } - } - - /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { - val stage = stageSubmitted.stage - activeStages += stage - - val poolName = Option(stageSubmitted.properties).map { - p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME) - }.getOrElse(DEFAULT_POOL_NAME) - stageToPool(stage) = poolName - - val description = Option(stageSubmitted.properties).flatMap { - p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) - } - description.map(d => stageToDescription(stage) = d) - - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) - stages += stage - } - - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val sid = taskStart.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) - tasksActive += taskStart.taskInfo - val taskList = stageToTaskInfos.getOrElse( - sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) - taskList += ((taskStart.taskInfo, None, None)) - stageToTaskInfos(sid) = taskList - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - val sid = taskEnd.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) - tasksActive -= taskEnd.taskInfo - val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = - taskEnd.reason match { - case e: ExceptionFailure => - stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 - (Some(e), e.metrics) - case _ => - stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 - (None, Option(taskEnd.taskMetrics)) - } - - stageToTime.getOrElseUpdate(sid, 0L) - val time = metrics.map(m => m.executorRunTime).getOrElse(0) - stageToTime(sid) += time - totalTime += time - - stageToShuffleRead.getOrElseUpdate(sid, 0L) - val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => - s.remoteBytesRead).getOrElse(0L) - stageToShuffleRead(sid) += shuffleRead - totalShuffleRead += shuffleRead - - stageToShuffleWrite.getOrElseUpdate(sid, 0L) - val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => - s.shuffleBytesWritten).getOrElse(0L) - stageToShuffleWrite(sid) += shuffleWrite - totalShuffleWrite += shuffleWrite - - val taskList = stageToTaskInfos.getOrElse( - sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) - taskList -= ((taskEnd.taskInfo, None, None)) - taskList += ((taskEnd.taskInfo, metrics, failureInfo)) - stageToTaskInfos(sid) = taskList - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { - jobEnd match { - case end: SparkListenerJobEnd => - end.jobResult match { - case JobFailed(ex, Some(stage)) => - activeStages -= stage - poolToActiveStages(stageToPool(stage)) -= stage - failedStages += stage - trimIfNecessary(failedStages) - case _ => - } - case _ => - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala deleted file mode 100644 index c83f102ff3..0000000000 --- a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.jobs - -import akka.util.Duration - -import java.text.SimpleDateFormat - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import scala.Seq -import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} - -import spark.ui.JettyUtils._ -import spark.{ExceptionFailure, SparkContext, Success, Utils} -import spark.scheduler._ -import collection.mutable -import spark.scheduler.cluster.SchedulingMode -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[spark] class JobProgressUI(val sc: SparkContext) { - private var _listener: Option[JobProgressListener] = None - def listener = _listener.get - val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - - private val indexPage = new IndexPage(this) - private val stagePage = new StagePage(this) - private val poolPage = new PoolPage(this) - - def start() { - _listener = Some(new JobProgressListener(sc)) - sc.addSparkListener(listener) - } - - def formatDuration(ms: Long) = Utils.msDurationToString(ms) - - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) - ) -} diff --git a/core/src/main/scala/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/spark/ui/jobs/PoolPage.scala deleted file mode 100644 index 7fb74dce40..0000000000 --- a/core/src/main/scala/spark/ui/jobs/PoolPage.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import scala.xml.{NodeSeq, Node} -import scala.collection.mutable.HashSet - -import spark.scheduler.Stage -import spark.ui.UIUtils._ -import spark.ui.Page._ - -/** Page showing specific pool details */ -private[spark] class PoolPage(parent: JobProgressUI) { - def listener = parent.listener - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val poolName = request.getParameter("poolname") - val poolToActiveStages = listener.poolToActiveStages - val activeStages = poolToActiveStages.get(poolName).toSeq.flatten - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) - - val pool = listener.sc.getPoolForName(poolName).get - val poolTable = new PoolTable(Seq(pool), listener) - - val content =

    Summary

    ++ poolTable.toNodeSeq() ++ -

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq() - - headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages) - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/spark/ui/jobs/PoolTable.scala deleted file mode 100644 index 621828f9c3..0000000000 --- a/core/src/main/scala/spark/ui/jobs/PoolTable.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.ui.jobs - -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.xml.Node - -import spark.scheduler.Stage -import spark.scheduler.cluster.Schedulable - -/** Table showing list of pools */ -private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { - - var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages - - def toNodeSeq(): Seq[Node] = { - listener.synchronized { - poolTable(poolRow, pools) - } - } - - private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], - rows: Seq[Schedulable] - ): Seq[Node] = { - - - - - - - - - - - {rows.map(r => makeRow(r, poolToActiveStages))} - -
    Pool NameMinimum SharePool WeightActive StagesRunning TasksSchedulingMode
    - } - - private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) - : Seq[Node] = { - val activeStages = poolToActiveStages.get(p.name) match { - case Some(stages) => stages.size - case None => 0 - } - - {p.name} - {p.minShare} - {p.weight} - {activeStages} - {p.runningTasks} - {p.schedulingMode} - - } -} - diff --git a/core/src/main/scala/spark/ui/jobs/StagePage.scala b/core/src/main/scala/spark/ui/jobs/StagePage.scala deleted file mode 100644 index c2341475c7..0000000000 --- a/core/src/main/scala/spark/ui/jobs/StagePage.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.jobs - -import java.util.Date - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.ui.UIUtils._ -import spark.ui.Page._ -import spark.util.Distribution -import spark.{ExceptionFailure, Utils} -import spark.scheduler.cluster.TaskInfo -import spark.executor.TaskMetrics - -/** Page showing statistics and task list for a given stage */ -private[spark] class StagePage(parent: JobProgressUI) { - def listener = parent.listener - val dateFmt = parent.dateFmt - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val stageId = request.getParameter("id").toInt - val now = System.currentTimeMillis() - - if (!listener.stageToTaskInfos.contains(stageId)) { - val content = -
    -

    Summary Metrics

    No tasks have started yet -

    Tasks

    No tasks have started yet -
    - return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) - } - - val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) - - val numCompleted = tasks.count(_._1.finished) - val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) - val hasShuffleRead = shuffleReadBytes > 0 - val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) - val hasShuffleWrite = shuffleWriteBytes > 0 - - var activeTime = 0L - listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) - - val summary = -
    -
      -
    • - CPU time: - {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)} -
    • - {if (hasShuffleRead) -
    • - Shuffle read: - {Utils.bytesToString(shuffleReadBytes)} -
    • - } - {if (hasShuffleWrite) -
    • - Shuffle write: - {Utils.bytesToString(shuffleWriteBytes)} -
    • - } -
    -
    - - val taskHeaders: Seq[String] = - Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ - Seq("GC Time") ++ - {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ - Seq("Errors") - - val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) - - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined)) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } - else { - val serviceTimes = validTasks.map{case (info, metrics, exception) => - metrics.get.executorRunTime.toDouble} - val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( - ms => parent.formatDuration(ms.toLong)) - - def getQuantileCols(data: Seq[Double]) = - Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) - - val shuffleReadSizes = validTasks.map { - case(info, metrics, exception) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble - } - val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes) - - val shuffleWriteSizes = validTasks.map { - case(info, metrics, exception) => - metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble - } - val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - def quantileRow(data: Seq[String]): Seq[Node] = {data.map(d => {d})} - Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) - } - - val content = - summary ++ -

    Summary Metrics for {numCompleted} Completed Tasks

    ++ -
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ -

    Tasks

    ++ taskTable; - - headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) - } - } - - - def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean) - (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = { - def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = - trace.map(e => {e.toString}) - val (info, metrics, exception) = taskData - - val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) - else metrics.map(m => m.executorRunTime).getOrElse(1) - val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) - else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") - val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - - - {info.taskId} - {info.status} - {info.taskLocality} - {info.host} - {dateFmt.format(new Date(info.launchTime))} - - {formatDuration} - - - {if (gcTime > 0) parent.formatDuration(gcTime) else ""} - - {if (shuffleRead) { - {metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")} - }} - {if (shuffleWrite) { - {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")} - }} - {exception.map(e => - - {e.className} ({e.description})
    - {fmtStackTrace(e.stackTrace)} -
    ).getOrElse("")} - - - } -} diff --git a/core/src/main/scala/spark/ui/jobs/StageTable.scala b/core/src/main/scala/spark/ui/jobs/StageTable.scala deleted file mode 100644 index 2b1bc984fc..0000000000 --- a/core/src/main/scala/spark/ui/jobs/StageTable.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.ui.jobs - -import java.util.Date - -import scala.xml.Node -import scala.collection.mutable.HashSet - -import spark.Utils -import spark.scheduler.cluster.{SchedulingMode, TaskInfo} -import spark.scheduler.Stage - - -/** Page showing list of all ongoing and recently finished stages */ -private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { - - val listener = parent.listener - val dateFmt = parent.dateFmt - val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR - - def toNodeSeq(): Seq[Node] = { - listener.synchronized { - stageTable(stageRow, stages) - } - } - - /** Special table which merges two header cells. */ - private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = { - - - - {if (isFairScheduler) {} else {}} - - - - - - - - - {rows.map(r => makeRow(r))} - -
    Stage IdPool NameDescriptionSubmittedDurationTasks: Succeeded/TotalShuffle ReadShuffle Write
    - } - - private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = { - val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) - -
    - - {completed}/{total} {failed} - -
    -
    -
    - } - - - private def stageRow(s: Stage): Seq[Node] = { - val submissionTime = s.submissionTime match { - case Some(t) => dateFmt.format(new Date(t)) - case None => "Unknown" - } - - val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { - case 0 => "" - case b => Utils.bytesToString(b) - } - - val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size - val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) - val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { - case f if f > 0 => "(%s failed)".format(f) - case _ => "" - } - val totalTasks = s.numPartitions - - val poolName = listener.stageToPool.get(s) - - val nameLink = {s.name} - val description = listener.stageToDescription.get(s) - .map(d =>
    {d}
    {nameLink}
    ).getOrElse(nameLink) - val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) - val duration = s.submissionTime.map(t => finishTime - t) - - - {s.id} - {if (isFairScheduler) { - {poolName.get}} - } - {description} - {submissionTime} - - {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")} - - - {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} - - {shuffleRead} - {shuffleWrite} - - } -} diff --git a/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala deleted file mode 100644 index 49ed069c75..0000000000 --- a/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.storage - -import akka.util.Duration - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import spark.{Logging, SparkContext} -import spark.ui.JettyUtils._ - -/** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { - implicit val timeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - val indexPage = new IndexPage(this) - val rddPage = new RDDPage(this) - - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) - ) -} diff --git a/core/src/main/scala/spark/ui/storage/IndexPage.scala b/core/src/main/scala/spark/ui/storage/IndexPage.scala deleted file mode 100644 index fc6273c694..0000000000 --- a/core/src/main/scala/spark/ui/storage/IndexPage.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.storage - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.storage.{RDDInfo, StorageUtils} -import spark.Utils -import spark.ui.UIUtils._ -import spark.ui.Page._ - -/** Page showing list of RDD's currently stored in the cluster */ -private[spark] class IndexPage(parent: BlockManagerUI) { - val sc = parent.sc - - def render(request: HttpServletRequest): Seq[Node] = { - val storageStatusList = sc.getExecutorStorageStatus - // Calculate macro-level statistics - - val rddHeaders = Seq( - "RDD Name", - "Storage Level", - "Cached Partitions", - "Fraction Cached", - "Size in Memory", - "Size on Disk") - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - val content = listingTable(rddHeaders, rddRow, rdds) - - headerSparkPage(content, parent.sc, "Storage ", Storage) - } - - def rddRow(rdd: RDDInfo): Seq[Node] = { - - - - {rdd.name} - - - {rdd.storageLevel.description} - - {rdd.numCachedPartitions} - {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.diskSize)} - - } -} diff --git a/core/src/main/scala/spark/ui/storage/RDDPage.scala b/core/src/main/scala/spark/ui/storage/RDDPage.scala deleted file mode 100644 index b128a5614d..0000000000 --- a/core/src/main/scala/spark/ui/storage/RDDPage.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.storage - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.Utils -import spark.storage.{StorageStatus, StorageUtils} -import spark.storage.BlockManagerMasterActor.BlockStatus -import spark.ui.UIUtils._ -import spark.ui.Page._ - - -/** Page showing storage details for a given RDD */ -private[spark] class RDDPage(parent: BlockManagerUI) { - val sc = parent.sc - - def render(request: HttpServletRequest): Seq[Node] = { - val id = request.getParameter("id") - val prefix = "rdd_" + id.toString - val storageStatusList = sc.getExecutorStorageStatus - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - - val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") - val workers = filteredStorageStatusList.map((prefix, _)) - val workerTable = listingTable(workerHeaders, workerRow, workers) - - val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", - "Executors") - - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) - val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) - val blocks = blockStatuses.map { - case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) - } - val blockTable = listingTable(blockHeaders, blockRow, blocks) - - val content = -
    -
    -
      -
    • - Storage Level: - {rddInfo.storageLevel.description} -
    • -
    • - Cached Partitions: - {rddInfo.numCachedPartitions} -
    • -
    • - Total Partitions: - {rddInfo.numPartitions} -
    • -
    • - Memory Size: - {Utils.bytesToString(rddInfo.memSize)} -
    • -
    • - Disk Size: - {Utils.bytesToString(rddInfo.diskSize)} -
    • -
    -
    -
    - -
    -
    -

    Data Distribution on {workers.size} Executors

    - {workerTable} -
    -
    - -
    -
    -

    {blocks.size} Partitions

    - {blockTable} -
    -
    ; - - headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) - } - - def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { - val (id, block, locations) = row - - {id} - - {block.storageLevel.description} - - - {Utils.bytesToString(block.memSize)} - - - {Utils.bytesToString(block.diskSize)} - - - {locations.map(l => {l}
    )} - - - } - - def workerRow(worker: (String, StorageStatus)): Seq[Node] = { - val (prefix, status) = worker - - {status.blockManagerId.host + ":" + status.blockManagerId.port} - - {Utils.bytesToString(status.memUsed(prefix))} - ({Utils.bytesToString(status.memRemaining)} Remaining) - - {Utils.bytesToString(status.diskUsed(prefix))} - - } -} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala deleted file mode 100644 index 9233277bdb..0000000000 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import akka.actor.{ActorSystem, ExtendedActorSystem} -import com.typesafe.config.ConfigFactory -import akka.util.duration._ -import akka.remote.RemoteActorRefProvider - - -/** - * Various utility classes for working with Akka. - */ -private[spark] object AkkaUtils { - - /** - * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the - * ActorSystem itself and its port (which is hard to get from Akka). - * - * Note: the `name` parameter is important, as even if a client sends a message to right - * host + port, if the system name is incorrect, Akka will drop the message. - */ - def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { - val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt - val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt - val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. - val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt - - val akkaConf = ConfigFactory.parseString(""" - akka.daemonic = on - akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] - akka.stdout-loglevel = "ERROR" - akka.actor.provider = "akka.remote.RemoteActorRefProvider" - akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" - akka.remote.netty.hostname = "%s" - akka.remote.netty.port = %d - akka.remote.netty.connection-timeout = %ds - akka.remote.netty.message-frame-size = %d MiB - akka.remote.netty.execution-pool-size = %d - akka.actor.default-dispatcher.throughput = %d - akka.remote.log-remote-lifecycle-events = %s - akka.remote.netty.write-timeout = %ds - """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - lifecycleEvents, akkaWriteTimeout)) - - val actorSystem = ActorSystem(name, akkaConf) - - // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a - // hack because Akka doesn't let you figure out the port through the public API yet. - val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get - return (actorSystem, boundPort) - } -} diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala deleted file mode 100644 index 0575497f5d..0000000000 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.Serializable -import java.util.{PriorityQueue => JPriorityQueue} -import scala.collection.generic.Growable -import scala.collection.JavaConverters._ - -/** - * Bounded priority queue. This class wraps the original PriorityQueue - * class and modifies it such that only the top K elements are retained. - * The top K elements are defined by an implicit Ordering[A]. - */ -class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) - extends Iterable[A] with Growable[A] with Serializable { - - private val underlying = new JPriorityQueue[A](maxSize, ord) - - override def iterator: Iterator[A] = underlying.iterator.asScala - - override def ++=(xs: TraversableOnce[A]): this.type = { - xs.foreach { this += _ } - this - } - - override def +=(elem: A): this.type = { - if (size < maxSize) underlying.offer(elem) - else maybeReplaceLowest(elem) - this - } - - override def +=(elem1: A, elem2: A, elems: A*): this.type = { - this += elem1 += elem2 ++= elems - } - - override def clear() { underlying.clear() } - - private def maybeReplaceLowest(a: A): Boolean = { - val head = underlying.peek() - if (head != null && ord.gt(a, head)) { - underlying.poll() - underlying.offer(a) - } else false - } -} - diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala deleted file mode 100644 index 47a28e2f76..0000000000 --- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.InputStream -import java.nio.ByteBuffer -import spark.storage.BlockManager - -/** - * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). - */ -private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) - extends InputStream { - - override def read(): Int = { - if (buffer == null || buffer.remaining() == 0) { - cleanUp() - -1 - } else { - buffer.get() & 0xFF - } - } - - override def read(dest: Array[Byte]): Int = { - read(dest, 0, dest.length) - } - - override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - if (buffer == null || buffer.remaining() == 0) { - cleanUp() - -1 - } else { - val amountToGet = math.min(buffer.remaining(), length) - buffer.get(dest, offset, amountToGet) - amountToGet - } - } - - override def skip(bytes: Long): Long = { - if (buffer != null) { - val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) - if (buffer.remaining() == 0) { - cleanUp() - } - amountToSkip - } else { - 0L - } - } - - /** - * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). - */ - private def cleanUp() { - if (buffer != null) { - if (dispose) { - BlockManager.dispose(buffer) - } - buffer = null - } - } -} diff --git a/core/src/main/scala/spark/util/Clock.scala b/core/src/main/scala/spark/util/Clock.scala deleted file mode 100644 index aa71a5b442..0000000000 --- a/core/src/main/scala/spark/util/Clock.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * An interface to represent clocks, so that they can be mocked out in unit tests. - */ -private[spark] trait Clock { - def getTime(): Long -} - -private[spark] object SystemClock extends Clock { - def getTime(): Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala deleted file mode 100644 index 210450892b..0000000000 --- a/core/src/main/scala/spark/util/CompletionIterator.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements - */ -abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ - def next = sub.next - def hasNext = { - val r = sub.hasNext - if (!r) { - completion - } - r - } - - def completion() -} - -object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { - new CompletionIterator[A,I](sub) { - def completion() = completionFunction - } - } -} diff --git a/core/src/main/scala/spark/util/Distribution.scala b/core/src/main/scala/spark/util/Distribution.scala deleted file mode 100644 index 5d4d7a6c50..0000000000 --- a/core/src/main/scala/spark/util/Distribution.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.PrintStream - -/** - * Util for getting some stats from a small sample of numeric values, with some handy summary functions. - * - * Entirely in memory, not intended as a good way to compute stats over large data sets. - * - * Assumes you are giving it a non-empty set of data - */ -class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { - require(startIdx < endIdx) - def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) - java.util.Arrays.sort(data, startIdx, endIdx) - val length = endIdx - startIdx - - val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) - - /** - * Get the value of the distribution at the given probabilities. Probabilities should be - * given from 0 to 1 - * @param probabilities - */ - def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = { - probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} - } - - private def closestIndex(p: Double) = { - math.min((p * length).toInt + startIdx, endIdx - 1) - } - - def showQuantiles(out: PrintStream = System.out) = { - out.println("min\t25%\t50%\t75%\tmax") - getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} - out.println - } - - def statCounter = StatCounter(data.slice(startIdx, endIdx)) - - /** - * print a summary of this distribution to the given PrintStream. - * @param out - */ - def summary(out: PrintStream = System.out) { - out.println(statCounter) - showQuantiles(out) - } -} - -object Distribution { - - def apply(data: Traversable[Double]): Option[Distribution] = { - if (data.size > 0) - Some(new Distribution(data)) - else - None - } - - def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { - out.println("min\t25%\t50%\t75%\tmax") - quantiles.foreach{q => out.print(q + "\t")} - out.println - } -} diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala deleted file mode 100644 index 3422280559..0000000000 --- a/core/src/main/scala/spark/util/IdGenerator.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.atomic.AtomicInteger - -/** - * A util used to get a unique generation ID. This is a wrapper around Java's - * AtomicInteger. An example usage is in BlockManager, where each BlockManager - * instance would start an Akka actor and we use this utility to assign the Akka - * actors unique names. - */ -private[spark] class IdGenerator { - private var id = new AtomicInteger - def next: Int = id.incrementAndGet -} diff --git a/core/src/main/scala/spark/util/IntParam.scala b/core/src/main/scala/spark/util/IntParam.scala deleted file mode 100644 index daf0d58fa2..0000000000 --- a/core/src/main/scala/spark/util/IntParam.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * An extractor object for parsing strings into integers. - */ -private[spark] object IntParam { - def unapply(str: String): Option[Int] = { - try { - Some(str.toInt) - } catch { - case e: NumberFormatException => None - } - } -} diff --git a/core/src/main/scala/spark/util/MemoryParam.scala b/core/src/main/scala/spark/util/MemoryParam.scala deleted file mode 100644 index 298562323a..0000000000 --- a/core/src/main/scala/spark/util/MemoryParam.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import spark.Utils - -/** - * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing - * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. - */ -private[spark] object MemoryParam { - def unapply(str: String): Option[Int] = { - try { - Some(Utils.memoryStringToMb(str)) - } catch { - case e: NumberFormatException => None - } - } -} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala deleted file mode 100644 index 92909e0959..0000000000 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} -import java.util.{TimerTask, Timer} -import spark.Logging - - -/** - * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) - */ -class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - private val delaySeconds = MetadataCleaner.getDelaySeconds - private val periodSeconds = math.max(10, delaySeconds / 10) - private val timer = new Timer(name + " cleanup timer", true) - - private val task = new TimerTask { - override def run() { - try { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - - if (delaySeconds > 0) { - logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + - "and period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} - - -object MetadataCleaner { - def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt - def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } -} - diff --git a/core/src/main/scala/spark/util/MutablePair.scala b/core/src/main/scala/spark/util/MutablePair.scala deleted file mode 100644 index 78d404e66b..0000000000 --- a/core/src/main/scala/spark/util/MutablePair.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - - -/** - * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to - * minimize object allocation. - * - * @param _1 Element 1 of this MutablePair - * @param _2 Element 2 of this MutablePair - */ -case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, - @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] - (var _1: T1, var _2: T2) - extends Product2[T1, T2] -{ - override def toString = "(" + _1 + "," + _2 + ")" - - override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] -} diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala deleted file mode 100644 index 22163ece8d..0000000000 --- a/core/src/main/scala/spark/util/NextIterator.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** Provides a basic/boilerplate Iterator implementation. */ -private[spark] abstract class NextIterator[U] extends Iterator[U] { - - private var gotNext = false - private var nextValue: U = _ - private var closed = false - protected var finished = false - - /** - * Method for subclasses to implement to provide the next element. - * - * If no next element is available, the subclass should set `finished` - * to `true` and may return any value (it will be ignored). - * - * This convention is required because `null` may be a valid value, - * and using `Option` seems like it might create unnecessary Some/None - * instances, given some iterators might be called in a tight loop. - * - * @return U, or set 'finished' when done - */ - protected def getNext(): U - - /** - * Method for subclasses to implement when all elements have been successfully - * iterated, and the iteration is done. - * - * Note: `NextIterator` cannot guarantee that `close` will be - * called because it has no control over what happens when an exception - * happens in the user code that is calling hasNext/next. - * - * Ideally you should have another try/catch, as in HadoopRDD, that - * ensures any resources are closed should iteration fail. - */ - protected def close() - - /** - * Calls the subclass-defined close method, but only once. - * - * Usually calling `close` multiple times should be fine, but historically - * there have been issues with some InputFormats throwing exceptions. - */ - def closeIfNeeded() { - if (!closed) { - close() - closed = true - } - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - nextValue = getNext() - if (finished) { - closeIfNeeded() - } - gotNext = true - } - } - !finished - } - - override def next(): U = { - if (!hasNext) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } -} diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala deleted file mode 100644 index 00f782bbe7..0000000000 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import scala.annotation.tailrec - -import java.io.OutputStream -import java.util.concurrent.TimeUnit._ - -class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) - val CHUNK_SIZE = 8192 - var lastSyncTime = System.nanoTime - var bytesWrittenSinceSync: Long = 0 - - override def write(b: Int) { - waitToWrite(1) - out.write(b) - } - - override def write(bytes: Array[Byte]) { - write(bytes, 0, bytes.length) - } - - @tailrec - override final def write(bytes: Array[Byte], offset: Int, length: Int) { - val writeSize = math.min(length - offset, CHUNK_SIZE) - if (writeSize > 0) { - waitToWrite(writeSize) - out.write(bytes, offset, writeSize) - write(bytes, offset + writeSize, length) - } - } - - override def flush() { - out.flush() - } - - override def close() { - out.close() - } - - @tailrec - private def waitToWrite(numBytes: Int) { - val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { - // It's okay to write; just update some variables and return - bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - bytesWrittenSinceSync = numBytes - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) - waitToWrite(numBytes) - } - } -} diff --git a/core/src/main/scala/spark/util/SerializableBuffer.scala b/core/src/main/scala/spark/util/SerializableBuffer.scala deleted file mode 100644 index 7e6842628a..0000000000 --- a/core/src/main/scala/spark/util/SerializableBuffer.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.nio.ByteBuffer -import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream} -import java.nio.channels.Channels - -/** - * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make - * it easier to pass ByteBuffers in case class messages. - */ -private[spark] -class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer - - private def readObject(in: ObjectInputStream) { - val length = in.readInt() - buffer = ByteBuffer.allocate(length) - var amountRead = 0 - val channel = Channels.newChannel(in) - while (amountRead < length) { - val ret = channel.read(buffer) - if (ret == -1) { - throw new EOFException("End of file before fully reading buffer") - } - amountRead += ret - } - buffer.rewind() // Allow us to read it later - } - - private def writeObject(out: ObjectOutputStream) { - out.writeInt(buffer.limit()) - if (Channels.newChannel(out).write(buffer) != buffer.limit()) { - throw new IOException("Could not fully write buffer to output stream") - } - buffer.rewind() // Allow us to write it again later - } -} diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala deleted file mode 100644 index 76358d4151..0000000000 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * A class for tracking the statistics of a set of numbers (count, mean and variance) in a - * numerically robust way. Includes support for merging two StatCounters. Based on - * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. - * - * @constructor Initialize the StatCounter with the given values. - */ -class StatCounter(values: TraversableOnce[Double]) extends Serializable { - private var n: Long = 0 // Running count of our values - private var mu: Double = 0 // Running mean of our values - private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) - - merge(values) - - /** Initialize the StatCounter with no values. */ - def this() = this(Nil) - - /** Add a value into this StatCounter, updating the internal statistics. */ - def merge(value: Double): StatCounter = { - val delta = value - mu - n += 1 - mu += delta / n - m2 += delta * (value - mu) - this - } - - /** Add multiple values into this StatCounter, updating the internal statistics. */ - def merge(values: TraversableOnce[Double]): StatCounter = { - values.foreach(v => merge(v)) - this - } - - /** Merge another StatCounter into this one, adding up the internal statistics. */ - def merge(other: StatCounter): StatCounter = { - if (other == this) { - merge(other.copy()) // Avoid overwriting fields in a weird order - } else { - if (n == 0) { - mu = other.mu - m2 = other.m2 - n = other.n - } else if (other.n != 0) { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) - } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - } - this - } - } - - /** Clone this StatCounter */ - def copy(): StatCounter = { - val other = new StatCounter - other.n = n - other.mu = mu - other.m2 = m2 - other - } - - def count: Long = n - - def mean: Double = mu - - def sum: Double = n * mu - - /** Return the variance of the values. */ - def variance: Double = { - if (n == 0) - Double.NaN - else - m2 / n - } - - /** - * Return the sample variance, which corrects for bias in estimating the variance by dividing - * by N-1 instead of N. - */ - def sampleVariance: Double = { - if (n <= 1) - Double.NaN - else - m2 / (n - 1) - } - - /** Return the standard deviation of the values. */ - def stdev: Double = math.sqrt(variance) - - /** - * Return the sample standard deviation of the values, which corrects for bias in estimating the - * variance by dividing by N-1 instead of N. - */ - def sampleStdev: Double = math.sqrt(sampleVariance) - - override def toString: String = { - "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) - } -} - -object StatCounter { - /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) - - /** Build a StatCounter from a list of values passed as variable-length arguments. */ - def apply(values: Double*) = new StatCounter(values) -} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala deleted file mode 100644 index 07772a0afb..0000000000 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions -import scala.collection.mutable.Map -import scala.collection.immutable -import spark.scheduler.MapStatus - -/** - * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in - * replacement of scala.collection.mutable.HashMap. - */ -class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { - val internalMap = new ConcurrentHashMap[A, (B, Long)]() - - def get(key: A): Option[B] = { - val value = internalMap.get(key) - if (value != null) Some(value._1) else None - } - - def iterator: Iterator[(A, B)] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) - } - - override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { - val newMap = new TimeStampedHashMap[A, B1] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.put(kv._1, (kv._2, currentTime)) - newMap - } - - override def - (key: A): Map[A, B] = { - val newMap = new TimeStampedHashMap[A, B] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.remove(key) - newMap - } - - override def += (kv: (A, B)): this.type = { - internalMap.put(kv._1, (kv._2, currentTime)) - this - } - - // Should we return previous value directly or as Option ? - def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, (value, currentTime)) - if (prev != null) Some(prev._1) else None - } - - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def update(key: A, value: B) { - this += ((key, value)) - } - - override def apply(key: A): B = { - val value = internalMap.get(key) - if (value == null) throw new NoSuchElementException() - value._1 - } - - override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) - } - - override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val kv = (entry.getKey, entry.getValue._1) - f(kv) - } - } - - def toMap: immutable.Map[A, B] = iterator.toMap - - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue._2 < threshTime) { - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() - -} diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala deleted file mode 100644 index 41e3fd8cba..0000000000 --- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import scala.collection.mutable.Set -import scala.collection.JavaConversions -import java.util.concurrent.ConcurrentHashMap - - -class TimeStampedHashSet[A] extends Set[A] { - val internalMap = new ConcurrentHashMap[A, Long]() - - def contains(key: A): Boolean = { - internalMap.contains(key) - } - - def iterator: Iterator[A] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) - } - - override def + (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet += elem - newSet - } - - override def - (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet -= elem - newSet - } - - override def += (key: A): this.type = { - internalMap.put(key, currentTime) - this - } - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def empty: Set[A] = new TimeStampedHashSet[A]() - - override def size(): Int = internalMap.size() - - override def foreach[U](f: (A) => U): Unit = { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - f(iterator.next.getKey) - } - } - - /** - * Removes old values that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue < threshTime) { - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala deleted file mode 100644 index a47cac3b96..0000000000 --- a/core/src/main/scala/spark/util/Vector.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -class Vector(val elements: Array[Double]) extends Serializable { - def length = elements.length - - def apply(index: Int) = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) + other(i)) - } - - def add(other: Vector) = this + other - - def - (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) - other(i)) - } - - def subtract(other: Vector) = this - other - - def dot(other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - return ans - } - - /** - * return (this + plus) dot other, but without creating any intermediate storage - * @param plus - * @param other - * @return - */ - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - if (length != plus.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - return ans - } - - def += (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - this - } - - def addInPlace(other: Vector) = this +=other - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def multiply (d: Double) = this * d - - def / (d: Double): Vector = this * (1 / d) - - def divide (d: Double) = this / d - - def unary_- = this * -1 - - def sum = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - return ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString = elements.mkString("(", ", ", ")") -} - -object Vector { - def apply(elements: Array[Double]) = new Vector(elements) - - def apply(elements: Double*) = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements: Array[Double] = Array.tabulate(length)(initializer) - return new Vector(elements) - } - - def zeros(length: Int) = new Vector(new Array[Double](length)) - - def ones(length: Int) = Vector(length, _ => 1) - - class Multiplier(num: Double) { - def * (vec: Vector) = vec * num - } - - implicit def doubleToMultiplier(num: Double) = new Multiplier(num) - - implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector) = t1 + t2 - - def zero(initialValue: Vector) = Vector.zeros(initialValue.length) - } - -} diff --git a/core/src/test/resources/test_metrics_config.properties b/core/src/test/resources/test_metrics_config.properties index 2b31ddf2eb..056a158456 100644 --- a/core/src/test/resources/test_metrics_config.properties +++ b/core/src/test/resources/test_metrics_config.properties @@ -1,6 +1,6 @@ *.sink.console.period = 10 *.sink.console.unit = seconds -*.source.jvm.class = spark.metrics.source.JvmSource +*.source.jvm.class = org.apache.spark.metrics.source.JvmSource master.sink.console.period = 20 master.sink.console.unit = minutes diff --git a/core/src/test/resources/test_metrics_system.properties b/core/src/test/resources/test_metrics_system.properties index d5479f0298..6f5ecea93a 100644 --- a/core/src/test/resources/test_metrics_system.properties +++ b/core/src/test/resources/test_metrics_system.properties @@ -1,7 +1,7 @@ *.sink.console.period = 10 *.sink.console.unit = seconds -test.sink.console.class = spark.metrics.sink.ConsoleSink -test.sink.dummy.class = spark.metrics.sink.DummySink -test.source.dummy.class = spark.metrics.source.DummySource +test.sink.console.class = org.apache.spark.metrics.sink.ConsoleSink +test.sink.dummy.class = org.apache.spark.metrics.sink.DummySink +test.source.dummy.class = org.apache.spark.metrics.source.DummySource test.sink.console.period = 20 test.sink.console.unit = minutes diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala new file mode 100644 index 0000000000..4434f3b87c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import collection.mutable +import java.util.Random +import scala.math.exp +import scala.math.signum +import org.apache.spark.SparkContext._ + +class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + test ("basic accumulation"){ + sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + d.foreach{x => acc += x} + acc.value should be (210) + + + val longAcc = sc.accumulator(0l) + val maxInt = Integer.MAX_VALUE.toLong + d.foreach{x => longAcc += maxInt + x} + longAcc.value should be (210l + maxInt * 20) + } + + test ("value not assignable from tasks") { + sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + evaluating {d.foreach{x => acc.value = x}} should produce [Exception] + } + + test ("add value to collection accumulators") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + d.foreach { + x => acc += x + } + val v = acc.value.asInstanceOf[mutable.Set[Int]] + for (i <- 1 to maxI) { + v should contain(i) + } + resetSparkContext() + } + } + + implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { + def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { + new mutable.HashSet[Any]() + } + } + + test ("value not readable in tasks") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + evaluating { + d.foreach { + x => acc.value += x + } + } should produce [SparkException] + resetSparkContext() + } + } + + test ("collection accumulators") { + val maxI = 1000 + for (nThreads <- List(1, 10)) { + // test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) + val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) + val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) + val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) + d.foreach { + x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} + } + + // Note that this is typed correctly -- no casts necessary + setAcc.value.size should be (maxI) + bufferAcc.value.size should be (2 * maxI) + mapAcc.value.size should be (maxI) + for (i <- 1 to maxI) { + setAcc.value should contain(i) + bufferAcc.value should contain(i) + mapAcc.value should contain (i -> i.toString) + } + resetSparkContext() + } + } + + test ("localValue readable in tasks") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet} + val d = sc.parallelize(groupedInts) + d.foreach { + x => acc.localValue ++= x + } + acc.value should be ( (0 to maxI).toSet) + resetSparkContext() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala new file mode 100644 index 0000000000..b3a53d928b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +class BroadcastSuite extends FunSuite with LocalSparkContext { + + test("basic broadcast") { + sc = new SparkContext("local", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === Set((1, 10), (2, 10))) + } + + test("broadcast variables accessed in multiple threads") { + sc = new SparkContext("local[10]", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..23b14f4245 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import java.io.File +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import storage.StorageLevel + +class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { + initLogging() + + var checkpointDir: File = _ + val partitioner = new HashPartitioner(2) + + override def beforeEach() { + super.beforeEach() + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + override def afterEach() { + super.afterEach() + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("basic checkpointing") { + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithIndexRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4, 2) + val numPartitions = parCollection.partitions.size + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.partitions.length === numPartitions) + assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + val numPartitions = blockRDD.partitions.size + blockRDD.checkpoint() + val result = blockRDD.collect() + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.partitions.length === numPartitions) + assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) + assert(blockRDD.collect() === result) + } + + test("ShuffledRDD") { + testCheckpointing(rdd => { + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD + new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + }) + } + + test("UnionRDD") { + def otherRDD = sc.makeRDD(1 to 10, 1) + + // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. + // Current implementation of UnionRDD has transient reference to parent RDDs, + // so only the partitions will reduce in serialized size, not the RDD. + testCheckpointing(_.union(otherRDD), false, true) + testParentCheckpointing(_.union(otherRDD), false, true) + } + + test("CartesianRDD") { + def otherRDD = sc.makeRDD(1 to 10, 1) + testCheckpointing(new CartesianRDD(sc, _, otherRDD)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + + // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. + // Note that this test is very specific to the current implementation of CartesianRDD. + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint() // checkpoint that MappedRDD + val cartesian = new CartesianRDD(sc, ones, ones) + val splitBeforeCheckpoint = + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) + cartesian.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) + assert( + (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && + (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), + "CartesianRDD.parents not updated after parent RDD checkpointed" + ) + } + + test("CoalescedRDD") { + testCheckpointing(_.coalesce(2)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing(_.coalesce(2), true, false) + + // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. + // Note that this test is very specific to the current implementation of CoalescedRDDPartitions + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint() // checkpoint that MappedRDD + val coalesced = new CoalescedRDD(ones, 2) + val splitBeforeCheckpoint = + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) + coalesced.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) + assert( + splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, + "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" + ) + } + + test("CoGroupedRDD") { + val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() + testCheckpointing(rdd => { + CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) + }, false, true) + + val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() + testParentCheckpointing(rdd => { + CheckpointSuite.cogroup( + longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) + }, false, true) + } + + test("ZippedRDD") { + testCheckpointing( + rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) + + // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed + // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing( + rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) + } + + test("CheckpointRDD with zero partitions") { + val rdd = new BlockRDD[Int](sc, Array[String]()) + assert(rdd.partitions.size === 0) + assert(rdd.isCheckpointed === false) + rdd.checkpoint() + assert(rdd.count() === 0) + assert(rdd.isCheckpointed === true) + assert(rdd.partitions.size === 0) + } + + /** + * Test checkpointing of the final RDD generated by the given operation. By default, + * this method tests whether the size of serialized RDD has reduced after checkpointing or not. + * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or + * not, but this is not done by default as usually the partitions do not refer to any RDD and + * therefore never store the lineage. + */ + def testCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean = true, + testRDDPartitionSize: Boolean = false + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + operatedRDD.checkpoint() + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the checkpoint file has been created + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced. If the RDD + // does not have any dependency to another RDD (e.g., ParallelCollection, + // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. + if (testRDDSize) { + logInfo("Size of " + rddType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the partitions has reduced. If the partitions + // do not have any non-transient reference to another RDD or another RDD's partitions, it + // does not refer to a lineage and therefore may not reduce in size after checkpointing. + // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions + // must be forgotten after checkpointing (to remove all reference to parent RDDs) and + // replaced with the HadooPartitions of the checkpointed RDD. + if (testRDDPartitionSize) { + logInfo("Size of " + rddType + " partitions " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * this RDD will remember the partitions and therefore potentially the whole lineage. + */ + def testParentCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean, + testRDDPartitionSize: Boolean + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.head.rdd + val rddType = operatedRDD.getClass.getSimpleName + val parentRDDType = parentRDD.getClass.getSimpleName + + // Get the partitions and dependencies of the parent in case they're lazily computed + parentRDD.dependencies + parentRDD.partitions + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced because of its parent being + // checkpointed. If this RDD or its parent RDD do not have any dependency + // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may + // not reduce in size after checkpointing. + if (testRDDSize) { + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the partitions has reduced because of its parent being + // checkpointed. If the partitions do not have any non-transient reference to another RDD + // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent + // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's + // partitions must have changed after checkpointing. + if (testRDDPartitionSize) { + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } + + } + + /** + * Generate an RDD with a long lineage of one-to-one dependencies. + */ + def generateLongLineageRDD(): RDD[Int] = { + var rdd = sc.makeRDD(1 to 100, 4) + for (i <- 1 to 50) { + rdd = rdd.map(x => x + 1) + } + rdd + } + + /** + * Generate an RDD with a long lineage specifically for CoGroupedRDD. + * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage + * and narrow dependency with this RDD. This method generate such an RDD by a sequence + * of cogroups and mapValues which creates a long lineage of narrow dependencies. + */ + def generateLongLineageRDDForCoGroupedRDD() = { + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + + var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) + for(i <- 1 to 10) { + cogrouped = cogrouped.mapValues(add).cogroup(ones) + } + cogrouped.mapValues(add) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, + Utils.serialize(rdd.partitions).length) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } +} + + +object CheckpointSuite { + // This is a custom cogroup function that does not use mapValues like + // the PairRDDFunctions.cogroup() + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + //println("First = " + first + ", second = " + second) + new CoGroupedRDD[K]( + Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), + part + ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + } + +} diff --git a/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala new file mode 100644 index 0000000000..8494899b98 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.NotSerializableException + +import org.scalatest.FunSuite +import org.apache.spark.LocalSparkContext._ +import SparkContext._ + +class ClosureCleanerSuite extends FunSuite { + test("closures inside an object") { + assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class") { + val obj = new TestClass + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class with no default constructor") { + val obj = new TestClassWithoutDefaultConstructor(5) + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures that don't use fields of the outer class") { + val obj = new TestClassWithoutFieldAccess + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("nested closures inside an object") { + assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } + + test("nested closures inside a class") { + val obj = new TestClassWithNesting(1) + assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } +} + +// A non-serializable class we create in closures to make sure that we aren't +// keeping references to unneeded variables from our outer closures. +class NonSerializable {} + +object TestObject { + def run(): Int = { + var nonSer = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + +class TestClass extends Serializable { + var x = 5 + + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +// This class is not serializable, but we aren't using any of its fields in our +// closures, so they won't have a $outer pointing to it and should still work. +class TestClassWithoutFieldAccess { + var nonSer = new NonSerializable + + def run(): Int = { + var nonSer2 = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + + +object TestObjectWithNesting { + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer + } + } +} + +class TestClassWithNesting(val y: Int) extends Serializable { + def getY = y + + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer + } + } +} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala new file mode 100644 index 0000000000..7a856d4081 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import network.ConnectionManagerId +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.matchers.ShouldMatchers +import org.scalatest.prop.Checkers +import org.scalatest.time.{Span, Millis} +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ +import org.eclipse.jetty.server.{Server, Request, Handler} + +import com.google.common.io.Files + +import scala.collection.mutable.ArrayBuffer + +import SparkContext._ +import storage.{GetBlock, BlockManagerWorker, StorageLevel} +import ui.JettyUtils + + +class NotSerializableClass +class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} + + +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { + + val clusterUrl = "local-cluster[2,1,512]" + + after { + System.clearProperty("spark.reducer.maxMbInFlight") + System.clearProperty("spark.storage.memoryFraction") + } + + test("task throws not serializable exception") { + // Ensures that executors do not crash when an exn is not serializable. If executors crash, + // this test will hang. Correct behavior is that executors don't crash but fail tasks + // and the scheduler throws a SparkException. + + // numSlaves must be less than numPartitions + val numSlaves = 3 + val numPartitions = 10 + + sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + val data = sc.parallelize(1 to 100, numPartitions). + map(x => throw new NotSerializableExn(new NotSerializableClass)) + intercept[SparkException] { + data.count() + } + resetSparkContext() + } + + test("local-cluster format") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[2, 1, 512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + } + + test("simple groupByKey") { + sc = new SparkContext(clusterUrl, "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5) + val groups = pairs.groupByKey(5).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey where map output sizes exceed maxMbInFlight") { + System.setProperty("spark.reducer.maxMbInFlight", "1") + sc = new SparkContext(clusterUrl, "test") + // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output + // file should be about 2.5 MB + val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) + val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() + assert(groups.length === 16) + assert(groups.map(_._2).sum === 2000) + // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block + } + + test("accumulators") { + sc = new SparkContext(clusterUrl, "test") + val accum = sc.accumulator(0) + sc.parallelize(1 to 10, 10).foreach(x => accum += x) + assert(accum.value === 55) + } + + test("broadcast variables") { + sc = new SparkContext(clusterUrl, "test") + val array = new Array[Int](100) + val bv = sc.broadcast(array) + array(2) = 3 // Change the array -- this should not be seen on workers + val rdd = sc.parallelize(1 to 10, 10) + val sum = rdd.map(x => bv.value.sum).reduce(_ + _) + assert(sum === 0) + } + + test("repeatedly failing task") { + sc = new SparkContext(clusterUrl, "test") + val accum = sc.accumulator(0) + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("more than 4 times")) + } + + test("caching") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).cache() + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) + + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + + // Get all the locations of the first partition and try to fetch the partitions + // from those locations. + val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray + val blockId = blockIds(0) + val blockManager = SparkEnv.get.blockManager + blockManager.master.getLocations(blockId).foreach(id => { + val bytes = BlockManagerWorker.syncGetBlock( + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) + val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList + assert(deserialized === (1 to 100).toList) + }) + } + + test("compute without caching when no partitions fit in memory") { + System.setProperty("spark.storage.memoryFraction", "0.0001") + sc = new SparkContext(clusterUrl, "test") + // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache + // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory + val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + System.clearProperty("spark.storage.memoryFraction") + } + + test("compute when only some partitions fit in memory") { + System.setProperty("spark.storage.memoryFraction", "0.01") + sc = new SparkContext(clusterUrl, "test") + // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache + // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions + // to make sure that *some* of them do fit though + val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + System.clearProperty("spark.storage.memoryFraction") + } + + test("passing environment variables to cluster") { + sc = new SparkContext(clusterUrl, "test", null, Nil, Map("TEST_VAR" -> "TEST_VALUE")) + val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() + assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) + } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + assert(data.count === 2) // force executors to start + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } + + test("recover from node failures with replication") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + // Using more than two nodes so we don't have a symmetric communication pattern and might + // cache a partially correct list of peers. + sc = new SparkContext("local-cluster[3,1,512]", "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + + assert(data.count === 4) + assert(data.map(markNodeIfIdentity).collect.size === 4) + assert(data.map(failOnMarkedIdentity).collect.size === 4) + + // Create a new replicated RDD to make sure that cached peer information doesn't cause + // problems. + val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2) + assert(data2.count === 2) + } + } + + test("unpersist RDDs") { + DistributedSuite.amMaster = true + sc = new SparkContext("local-cluster[3,1,512]", "test") + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + data.count + assert(sc.persistentRdds.isEmpty === false) + data.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + } + + test("job should fail if TaskResult exceeds Akka frame size") { + // We must use local-cluster mode since results are returned differently + // when running under LocalScheduler: + sc = new SparkContext("local-cluster[1,1,512]", "test") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} + val exception = intercept[SparkException] { + rdd.reduce((x, y) => x) + } + exception.getMessage should endWith("result exceeded Akka frame size") + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. + var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } +} diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala new file mode 100644 index 0000000000..b08aad1a6f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +import org.apache.log4j.Logger +import org.apache.log4j.Level + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + assert(System.getenv("SPARK_HOME") != null) + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(30 seconds) { + Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + Logger.getRootLogger().setLevel(Level.WARN) + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala new file mode 100644 index 0000000000..ee89a7a387 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import SparkContext._ + +// Common state shared by FailureSuite-launched tasks. We use a global object +// for this because any local variables used in the task closures will rightfully +// be copied for each task, so there's no other way for them to share state. +object FailureSuiteState { + var tasksRun = 0 + var tasksFailed = 0 + + def clear() { + synchronized { + tasksRun = 0 + tasksFailed = 0 + } + } +} + +class FailureSuite extends FunSuite with LocalSparkContext { + + // Run a 3-task map job in which task 1 deterministically fails once, and check + // whether the job completes successfully and we ran 4 tasks in total. + test("failure in a single-stage job") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + x * x + }.collect() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(results.toList === List(1,4,9)) + FailureSuiteState.clear() + } + + // Run a map-reduce job in which a reduce task deterministically fails once. + test("failure in a two-stage job") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { + case (k, v) => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (k == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (k, v(0) * v(0)) + }.collect() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(results.toSet === Set((1, 1), (2, 4), (3, 9))) + FailureSuiteState.clear() + } + + test("failure because task results are not serializable") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) + + val thrown = intercept[SparkException] { + results.collect() + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + + test("failure because task closure is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + + // Non-serializable closure in the final result stage + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => a).count() + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + // Non-serializable closure in an earlier stage + val thrown1 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() + } + assert(thrown1.getClass === classOf[SparkException]) + assert(thrown1.getMessage.contains("NotSerializableException")) + + // Non-serializable closure in foreach function + val thrown2 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).foreach(x => println(a)) + } + assert(thrown2.getClass === classOf[SparkException]) + assert(thrown2.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + + // TODO: Need to add tests with shuffle fetch failures. +} + + diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala new file mode 100644 index 0000000000..35d1d41af1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import com.google.common.io.Files +import org.scalatest.FunSuite +import java.io.{File, PrintWriter, FileReader, BufferedReader} +import SparkContext._ + +class FileServerSuite extends FunSuite with LocalSparkContext { + + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ + + override def beforeEach() { + super.beforeEach() + // Create a sample text file + val tmpdir = new File(Files.createTempDir(), "test") + tmpdir.mkdir() + tmpFile = new File(tmpdir, "FileServerSuite.txt") + val pw = new PrintWriter(tmpFile) + pw.println("100") + pw.close() + } + + override def afterEach() { + super.afterEach() + // Clean up downloaded file + if (tmpFile.exists) { + tmpFile.delete() + } + } + + test("Distributing files locally") { + sc = new SparkContext("local[4]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test("Distributing files locally using URL as input") { + // addFile("file:///....") + sc = new SparkContext("local[4]", "test") + sc.addFile(new File(tmpFile.toString).toURI.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test ("Dynamically adding JARS locally") { + sc = new SparkContext("local[4]", "test") + val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = Thread.currentThread.getContextClassLoader() + .loadClass("org.uncommons.maths.Maths") + .getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } + + test("Distributing files on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test ("Dynamically adding JARS on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = Thread.currentThread.getContextClassLoader() + .loadClass("org.uncommons.maths.Maths") + .getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } +} diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala new file mode 100644 index 0000000000..7b82a4cdd9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{FileWriter, PrintWriter, File} + +import scala.io.Source + +import com.google.common.io.Files +import org.scalatest.FunSuite +import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + + +import SparkContext._ + +class FileSuite extends FunSuite with LocalSparkContext { + + test("text files") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 4) + nums.saveAsTextFile(outputDir) + // Read the plain text file and check it's OK + val outputFile = new File(outputDir, "part-00000") + val content = Source.fromFile(outputFile).mkString + assert(content === "1\n2\n3\n4\n") + // Also try reading it in as a text file RDD + assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + } + + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFiles") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFile with writable key") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile with writable value") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile with writable key and value") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("implicit conversions in reading SequenceFiles") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) + nums.saveAsSequenceFile(outputDir) + // Similar to the tests above, we read a SequenceFile, but this time we pass type params + // that are convertable to Writable instead of calling sequenceFile[IntWritable, Text] + val output1 = sc.sequenceFile[Int, String](outputDir) + assert(output1.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) + // Also try having one type be a subclass of Writable and one not + val output2 = sc.sequenceFile[Int, Text](outputDir) + assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + val output3 = sc.sequenceFile[IntWritable, String](outputDir) + assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("object files of ints") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 4) + nums.saveAsObjectFile(outputDir) + // Try reading the output back as an object file + val output = sc.objectFile[Int](outputDir) + assert(output.collect().toList === List(1, 2, 3, 4)) + } + + test("object files of complex types") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) + nums.saveAsObjectFile(outputDir) + // Try reading the output back as an object file + val output = sc.objectFile[(Int, String)](outputDir) + assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) + } + + test("write SequenceFile using new Hadoop API") { + import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]]( + outputDir) + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("read SequenceFile using new Hadoop API") { + import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + val output = + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("file caching") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val out = new FileWriter(tempDir + "/input") + out.write("Hello world!\n") + out.write("What's up?\n") + out.write("Goodbye\n") + out.close() + val rdd = sc.textFile(tempDir + "/input").cache() + assert(rdd.count() === 3) + assert(rdd.count() === 3) + assert(rdd.count() === 3) + } +} diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java new file mode 100644 index 0000000000..8a869c9005 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -0,0 +1,865 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.util.*; + +import com.google.common.base.Optional; +import scala.Tuple2; + +import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.SequenceFileInputFormat; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.Job; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.partial.BoundedDouble; +import org.apache.spark.partial.PartialResult; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.StatCounter; + + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port"); + } + + static class ReverseIntComparator implements Comparator, Serializable { + + @Override + public int compare(Integer a, Integer b) { + if (a > b) return -1; + else if (a < b) return 1; + else return 0; + } + }; + + @Test + public void sparkContextUnion() { + // Union of non-specialized JavaRDDs + List strings = Arrays.asList("Hello", "World"); + JavaRDD s1 = sc.parallelize(strings); + JavaRDD s2 = sc.parallelize(strings); + // Varargs + JavaRDD sUnion = sc.union(s1, s2); + Assert.assertEquals(4, sUnion.count()); + // List + List> list = new ArrayList>(); + list.add(s2); + sUnion = sc.union(s1, list); + Assert.assertEquals(4, sUnion.count()); + + // Union of JavaDoubleRDDs + List doubles = Arrays.asList(1.0, 2.0); + JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD dUnion = sc.union(d1, d2); + Assert.assertEquals(4, dUnion.count()); + + // Union of JavaPairRDDs + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(1, 2)); + pairs.add(new Tuple2(3, 4)); + JavaPairRDD p1 = sc.parallelizePairs(pairs); + JavaPairRDD p2 = sc.parallelizePairs(pairs); + JavaPairRDD pUnion = sc.union(p1, p2); + Assert.assertEquals(4, pUnion.count()); + } + + @Test + public void sortByKey() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 4)); + pairs.add(new Tuple2(3, 2)); + pairs.add(new Tuple2(-1, 1)); + + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + // Default comparator + JavaPairRDD sortedRDD = rdd.sortByKey(); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + List> sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + + // Custom comparator + sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + } + + static int foreachCalls = 0; + + @Test + public void foreach() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach(new VoidFunction() { + @Override + public void call(String s) { + foreachCalls++; + } + }); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void lookup() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + Assert.assertEquals(2, categories.lookup("Oranges").size()); + Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); + } + + @Test + public void groupBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function isOdd = new Function() { + @Override + public Boolean call(Integer x) { + return x % 2 == 0; + } + }; + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + + oddsAndEvens = rdd.groupBy(isOdd, 1); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + } + + @Test + public void cogroup() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD, List>> cogrouped = categories.cogroup(prices); + Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString()); + Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString()); + + cogrouped.collect(); + } + + @Test + public void leftOuterJoin() { + JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(1, 2), + new Tuple2(2, 1), + new Tuple2(3, 1) + )); + JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 'x'), + new Tuple2(2, 'y'), + new Tuple2(2, 'z'), + new Tuple2(4, 'w') + )); + List>>> joined = + rdd1.leftOuterJoin(rdd2).collect(); + Assert.assertEquals(5, joined.size()); + Tuple2>> firstUnmatched = + rdd1.leftOuterJoin(rdd2).filter( + new Function>>, Boolean>() { + @Override + public Boolean call(Tuple2>> tup) + throws Exception { + return !tup._2()._2().isPresent(); + } + }).first(); + Assert.assertEquals(3, firstUnmatched._1().intValue()); + } + + @Test + public void foldReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + + int sum = rdd.fold(0, add); + Assert.assertEquals(33, sum); + + sum = rdd.reduce(add); + Assert.assertEquals(33, sum); + } + + @Test + public void foldByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD sums = rdd.foldByKey(0, + new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + + @Test + public void reduceByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD counts = rdd.reduceByKey( + new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); + Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); + Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); + + Map localCounts = counts.collectAsMap(); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + + localCounts = rdd.reduceByKeyLocally(new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + } + + @Test + public void approximateResults() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Map countsByValue = rdd.countByValue(); + Assert.assertEquals(2, countsByValue.get(1).longValue()); + Assert.assertEquals(1, countsByValue.get(13).longValue()); + + PartialResult> approx = rdd.countByValueApprox(1); + Map finalValue = approx.getFinalValue(); + Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01); + Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01); + } + + @Test + public void take() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Assert.assertEquals(1, rdd.first().intValue()); + List firstTwo = rdd.take(2); + List sample = rdd.takeSample(false, 2, 42); + } + + @Test + public void cartesian() { + JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); + JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); + Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + } + + @Test + public void javaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + JavaDoubleRDD distinct = rdd.distinct(); + Assert.assertEquals(5, distinct.count()); + JavaDoubleRDD filter = rdd.filter(new Function() { + @Override + public Boolean call(Double x) { + return x > 2.0; + } + }); + Assert.assertEquals(3, filter.count()); + JavaDoubleRDD union = rdd.union(rdd); + Assert.assertEquals(12, union.count()); + union = union.cache(); + Assert.assertEquals(12, union.count()); + + Assert.assertEquals(20, rdd.sum(), 0.01); + StatCounter stats = rdd.stats(); + Assert.assertEquals(20, stats.sum(), 0.01); + Assert.assertEquals(20/6.0, rdd.mean(), 0.01); + Assert.assertEquals(20/6.0, rdd.mean(), 0.01); + Assert.assertEquals(6.22222, rdd.variance(), 0.01); + Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01); + Assert.assertEquals(2.49444, rdd.stdev(), 0.01); + Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01); + + Double first = rdd.first(); + List take = rdd.take(5); + } + + @Test + public void map() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + @Override + public Double call(Integer x) { + return 1.0 * x; + } + }).cache(); + JavaPairRDD pairs = rdd.map(new PairFunction() { + @Override + public Tuple2 call(Integer x) { + return new Tuple2(x, x); + } + }).cache(); + JavaRDD strings = rdd.map(new Function() { + @Override + public String call(Integer x) { + return x.toString(); + } + }).cache(); + } + + @Test + public void flatMap() { + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", + "The quick brown fox jumps over the lazy dog.")); + JavaRDD words = rdd.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Arrays.asList(x.split(" ")); + } + }); + Assert.assertEquals("Hello", words.first()); + Assert.assertEquals(11, words.count()); + + JavaPairRDD pairs = rdd.flatMap( + new PairFlatMapFunction() { + + @Override + public Iterable> call(String s) { + List> pairs = new LinkedList>(); + for (String word : s.split(" ")) pairs.add(new Tuple2(word, word)); + return pairs; + } + } + ); + Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); + Assert.assertEquals(11, pairs.count()); + + JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction() { + @Override + public Iterable call(String s) { + List lengths = new LinkedList(); + for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + return lengths; + } + }); + Double x = doubles.first(); + Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); + Assert.assertEquals(11, pairs.count()); + } + + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMap( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) throws Exception { + return item.swap(); + } + }).collect(); + } + + @Test + public void mapPartitions() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitions( + new FlatMapFunction, Integer>() { + @Override + public Iterable call(Iterator iter) { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum); + } + }); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test + public void persist() { + JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals(20, doubleRDD.sum(), 0.1); + + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals("a", pairRDD.first()._2()); + + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + rdd = rdd.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals(1, rdd.first().intValue()); + } + + @Test + public void iterator() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); + TaskContext context = new TaskContext(0, 0, 0, null); + Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); + } + + @Test + public void glom() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + Assert.assertEquals("[1, 2]", rdd.glom().first().toString()); + } + + // File input / output tests are largely adapted from FileSuite: + + @Test + public void textFiles() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir); + // Read the plain text file and check it's OK + File outputFile = new File(outputDir, "part-00000"); + String content = Files.toString(outputFile, Charsets.UTF_8); + Assert.assertEquals("1\n2\n3\n4\n", content); + // Also try reading it in as a text file RDD + List expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void sequenceFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + // Try reading the output back as an object file + JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, + Text.class).map(new PairFunction, Integer, String>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(pair._1().get(), pair._2().toString()); + } + }); + Assert.assertEquals(pairs, readRDD.collect()); + } + + @Test + public void writeWithNewAPIHadoopFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, + Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void readWithNewAPIHadoopFile() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.newAPIHadoopFile(outputDir, + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, + Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void objectFilesOfInts() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsObjectFile(outputDir); + // Try reading the output back as an object file + List expected = Arrays.asList(1, 2, 3, 4); + JavaRDD readRDD = sc.objectFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void objectFilesOfComplexTypes() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + rdd.saveAsObjectFile(outputDir); + // Try reading the output back as an object file + JavaRDD> readRDD = sc.objectFile(outputDir); + Assert.assertEquals(pairs, readRDD.collect()); + } + + @Test + public void hadoopFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void zip() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + @Override + public Double call(Integer x) { + return 1.0 * x; + } + }); + JavaPairRDD zipped = rdd.zip(doubles); + zipped.count(); + } + + @Test + public void zipPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2, Iterator, Integer> sizesFn = + new FlatMapFunction2, Iterator, Integer>() { + @Override + public Iterable call(Iterator i, Iterator s) { + int sizeI = 0; + int sizeS = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS); + } + }; + + JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); + } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(new Function() { + public String call(Integer t) throws Exception { + return t.toString(); + } + }).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } + + @Test + public void checkpointAndComputation() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + } + + @Test + public void checkpointAndRestore() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + + Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + } + + @Test + public void mapOnPairRDD() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); + JavaPairRDD rdd2 = rdd1.map(new PairFunction() { + @Override + public Tuple2 call(Integer i) throws Exception { + return new Tuple2(i, i % 2); + } + }); + JavaPairRDD rdd3 = rdd2.map( + new PairFunction, Integer, Integer>() { + @Override + public Tuple2 call(Tuple2 in) throws Exception { + return new Tuple2(in._2(), in._1()); + } + }); + Assert.assertEquals(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(0, 2), + new Tuple2(1, 3), + new Tuple2(0, 4)), rdd3.collect()); + + } +} diff --git a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala new file mode 100644 index 0000000000..d7b23c93fe --- /dev/null +++ b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable + +import org.scalatest.FunSuite +import com.esotericsoftware.kryo._ + +import KryoTest._ + +class KryoSerializerSuite extends FunSuite with SharedSparkContext { + test("basic types") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(1) + check(1L) + check(1.0f) + check(1.0) + check(1.toByte) + check(1.toShort) + check("") + check("hello") + check(Integer.MAX_VALUE) + check(Integer.MIN_VALUE) + check(java.lang.Long.MAX_VALUE) + check(java.lang.Long.MIN_VALUE) + check[String](null) + check(Array(1, 2, 3)) + check(Array(1L, 2L, 3L)) + check(Array(1.0, 2.0, 3.0)) + check(Array(1.0f, 2.9f, 3.9f)) + check(Array("aaa", "bbb", "ccc")) + check(Array("aaa", "bbb", null)) + check(Array(true, false, true)) + check(Array('a', 'b', 'c')) + check(Array[Int]()) + check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) + } + + test("pairs") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 1)) + check((1, 1L)) + check((1L, 1)) + check((1L, 1L)) + check((1.0, 1)) + check((1, 1.0)) + check((1.0, 1.0)) + check((1.0, 1L)) + check((1L, 1.0)) + check((1.0, 1L)) + check(("x", 1)) + check(("x", 1.0)) + check(("x", 1L)) + check((1, "x")) + check((1.0, "x")) + check((1L, "x")) + check(("x", "x")) + } + + test("Scala data structures") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + } + + test("custom registrator") { + import KryoTest._ + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + check(CaseClass(17, "hello")) + + val c1 = new ClassWithNoArgConstructor + c1.x = 32 + check(c1) + + val c2 = new ClassWithoutNoArgConstructor(47) + check(c2) + + val hashMap = new java.util.HashMap[String, String] + hashMap.put("foo", "bar") + check(hashMap) + + System.clearProperty("spark.kryo.registrator") + } + + test("kryo with collect") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + assert(control === result.toSeq) + } + + test("kryo with parallelize") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() + assert (control === result.toSeq) + } + + test("kryo with parallelize for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) + } + + test("kryo with parallelize for primitive arrays") { + assert (sc.parallelize( Array(1, 2, 3) ).count === 3) + } + + test("kryo with collect for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) + } + + test("kryo with reduce") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(control.sum === result) + } + + // TODO: this still doesn't work + ignore("kryo with fold") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(10 + control.sum === result) + } + + override def beforeAll() { + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + super.beforeAll() + } + + override def afterAll() { + super.afterAll() + System.clearProperty("spark.kryo.registrator") + System.clearProperty("spark.serializer") + } +} + +object KryoTest { + case class CaseClass(i: Int, s: String) {} + + class ClassWithNoArgConstructor { + var x: Int = 0 + override def equals(other: Any) = other match { + case c: ClassWithNoArgConstructor => x == c.x + case _ => false + } + } + + class ClassWithoutNoArgConstructor(val x: Int) { + override def equals(other: Any) = other match { + case c: ClassWithoutNoArgConstructor => x == c.x + case _ => false + } + } + + class MyRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.register(classOf[CaseClass]) + k.register(classOf[ClassWithNoArgConstructor]) + k.register(classOf[ClassWithoutNoArgConstructor]) + k.register(classOf[java.util.HashMap[_, _]]) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala new file mode 100644 index 0000000000..6ec124da9c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterEach +import org.scalatest.BeforeAndAfterAll + +import org.jboss.netty.logging.InternalLoggerFactory +import org.jboss.netty.logging.Slf4JLoggerFactory + +/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var sc: SparkContext = _ + + override def beforeAll() { + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + super.beforeAll() + } + + override def afterEach() { + resetSparkContext() + super.afterEach() + } + + def resetSparkContext() = { + if (sc != null) { + LocalSparkContext.stop(sc) + sc = null + } + } + +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala new file mode 100644 index 0000000000..6013320eaa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils + +class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { + + test("compressSize") { + assert(MapOutputTracker.compressSize(0L) === 0) + assert(MapOutputTracker.compressSize(1L) === 1) + assert(MapOutputTracker.compressSize(2L) === 8) + assert(MapOutputTracker.compressSize(10L) === 25) + assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) + assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) + // This last size is bigger than we can encode in a byte, so check that we just return 255 + assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) + } + + test("decompressSize") { + assert(MapOutputTracker.decompressSize(0) === 0) + for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { + val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) + assert(size2 >= 0.99 * size && size2 <= 1.11 * size, + "size " + size + " decompressed to " + size2 + ", which is out of range") + } + } + + test("master start and stop") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.stop() + } + + test("master register and fetch") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + val size10000 = MapOutputTracker.decompressSize(compressedSize10000) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + val statuses = tracker.getServerStatuses(10, 0) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), + (BlockManagerId("b", "hostB", 1000, 0), size10000))) + tracker.stop() + } + + test("master register and unregister and fetch") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + val size10000 = MapOutputTracker.decompressSize(compressedSize10000) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize1000, compressedSize1000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000, compressedSize1000))) + + // As if we had two simulatenous fetch failures + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + + // The remaining reduce task might try to grab the output despite the shuffle failure; + // this should cause it to fail, and the scheduler will ignore the failure due to the + // stage already being aborted. + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + val masterTracker = new MapOutputTracker() + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) + val slaveTracker = new MapOutputTracker() + slaveTracker.trackerActor = slaveSystem.actorFor( + "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } +} diff --git a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala new file mode 100644 index 0000000000..f79752b34e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite + +import com.google.common.io.Files +import org.apache.spark.SparkContext._ + + +class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("groupByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with duplicates") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with negative key hash codes") { + val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesForMinus1 = groups.find(_._1 == -1).get._2 + assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with many output partitions") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey(10).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("reduceByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with collectAsMap") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collectAsMap() + assert(sums.size === 2) + assert(sums(1) === 7) + assert(sums(2) === 1) + } + + test("reduceByKey with many output partitons") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_, 10).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with partitioner") { + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } + + test("join") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("join all-to-all") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) + val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (1, 'y')), + (1, (2, 'x')), + (1, (2, 'y')), + (1, (3, 'x')), + (1, (3, 'y')) + )) + } + + test("leftOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + } + + test("rightOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + } + + test("join with no matches") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 0) + } + + test("join with many output partitions") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2, 10).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("groupWith") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + } + + test("zero-partition RDD") { + val emptyDir = Files.createTempDir() + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.size == 0) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } + + test("keys and values") { + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } + + test("default partitioner uses partition size") { + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === a.partitions.size) + } + + test("subtract with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + // Ideally we could keep the original partitioner... + assert(c.partitioner === None) + } + + test("subtractByKey") { + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) + val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitions.size === a.partitions.size) + } + + test("subtractByKey with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitioner.get === p) + } + + test("foldByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..adbe805916 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala @@ -0,0 +1,28 @@ +package org.apache.spark + +import org.scalatest.FunSuite +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.PartitionPruningRDD + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + test("Pruned Partitions inherit locality prefs correctly") { + class TestPartition(i: Int) extends Partition { + def index = i + } + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(1), + new TestPartition(2), + new TestPartition(3)) + } + def compute(split: Partition, context: TaskContext) = {Iterator()} + } + val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) + val p = prunedRDD.partitions(0) + assert(p.index == 2) + assert(prunedRDD.partitions.length == 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala new file mode 100644 index 0000000000..7669cf6fb1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import scala.collection.mutable.ArrayBuffer +import SparkContext._ +import org.apache.spark.util.StatCounter +import scala.math.abs + +class PartitioningSuite extends FunSuite with SharedSparkContext { + + test("HashPartitioner equality") { + val p2 = new HashPartitioner(2) + val p4 = new HashPartitioner(4) + val anotherP4 = new HashPartitioner(4) + assert(p2 === p2) + assert(p4 === p4) + assert(p2 != p4) + assert(p4 != p2) + assert(p4 === anotherP4) + assert(anotherP4 === p4) + } + + test("RangePartitioner equality") { + // Make an RDD where all the elements are the same so that the partition range bounds + // are deterministically all the same. + val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) + + val p2 = new RangePartitioner(2, rdd) + val p4 = new RangePartitioner(4, rdd) + val anotherP4 = new RangePartitioner(4, rdd) + val descendingP2 = new RangePartitioner(2, rdd, false) + val descendingP4 = new RangePartitioner(4, rdd, false) + + assert(p2 === p2) + assert(p4 === p4) + assert(p2 != p4) + assert(p4 != p2) + assert(p4 === anotherP4) + assert(anotherP4 === p4) + assert(descendingP2 === descendingP2) + assert(descendingP4 === descendingP4) + assert(descendingP2 != descendingP4) + assert(descendingP4 != descendingP2) + assert(p2 != descendingP2) + assert(p4 != descendingP4) + assert(descendingP2 != p2) + assert(descendingP4 != p4) + } + + test("HashPartitioner not equal to RangePartitioner") { + val rdd = sc.parallelize(1 to 10).map(x => (x, x)) + val rangeP2 = new RangePartitioner(2, rdd) + val hashP2 = new HashPartitioner(2) + assert(rangeP2 === rangeP2) + assert(hashP2 === hashP2) + assert(hashP2 != rangeP2) + assert(rangeP2 != hashP2) + } + + test("partitioner preservation") { + val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) + + val grouped2 = rdd.groupByKey(2) + val grouped4 = rdd.groupByKey(4) + val reduced2 = rdd.reduceByKey(_ + _, 2) + val reduced4 = rdd.reduceByKey(_ + _, 4) + + assert(rdd.partitioner === None) + + assert(grouped2.partitioner === Some(new HashPartitioner(2))) + assert(grouped4.partitioner === Some(new HashPartitioner(4))) + assert(reduced2.partitioner === Some(new HashPartitioner(2))) + assert(reduced4.partitioner === Some(new HashPartitioner(4))) + + assert(grouped2.groupByKey().partitioner === grouped2.partitioner) + assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner) + assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner) + assert(grouped4.groupByKey().partitioner === grouped4.partitioner) + assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner) + assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner) + + assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) + + assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) + + assert(grouped2.map(_ => 1).partitioner === None) + assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) + assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) + assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) + } + + test("partitioning Java arrays should fail") { + val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrPairs: RDD[(Array[Int], Int)] = + sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + + assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + // We can't catch all usages of arrays, since they might occur inside other collections: + //assert(fails { arrPairs.distinct() }) + assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + } + + test("zero-length partitions should be correctly handled") { + // Create RDD with some consecutive empty partitions (including the "first" one) + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } +} diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala new file mode 100644 index 0000000000..2e851d892d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import SparkContext._ + +class PipedRDDSuite extends FunSuite with SharedSparkContext { + + test("basic pipe") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + val piped = nums.pipe(Seq("cat")) + + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + } + + test("advanced pipe") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)). + pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + + test("pipe with env variable") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.size === 2) + assert(c(0) === "LALALA") + assert(c(1) === "LALALA") + } + + test("pipe with non-zero exit status") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) + intercept[SparkException] { + piped.collect() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/RDDSuite.scala b/core/src/test/scala/org/apache/spark/RDDSuite.scala new file mode 100644 index 0000000000..342ba8adb2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/RDDSuite.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.HashMap +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import scala.collection.parallel.mutable + +class RDDSuite extends FunSuite with SharedSparkContext { + + test("basic operations") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.collect().toList === List(1, 2, 3, 4)) + val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? + assert(dups.distinct.collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) + assert(nums.reduce(_ + _) === 10) + assert(nums.fold(0)(_ + _) === 10) + assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) + assert(nums.filter(_ > 2).collect().toList === List(3, 4)) + assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) + assert(partitionSums.collect().toList === List(3, 7)) + + val partitionSumsWithSplit = nums.mapPartitionsWithSplit { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + val partitionSumsWithIndex = nums.mapPartitionsWithIndex { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } + } + + test("SparkContext.union") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + } + + test("aggregate") { + val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) + type StringMap = HashMap[String, Int] + val emptyMap = new StringMap { + override def default(key: String): Int = 0 + } + val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { + map(pair._1) += pair._2 + map + } + val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { + for ((key, value) <- map2) { + map1(key) += value + } + map1 + } + val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) + assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) + } + + test("basic caching") { + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("caching with failures") { + val onlySplit = new Partition { override def index: Int = 0 } + var shouldFail = true + val rdd = new RDD[Int](sc, Nil) { + override def getPartitions: Array[Partition] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + if (shouldFail) { + throw new Exception("injected failure") + } else { + return Array(1, 2, 3, 4).iterator + } + } + }.cache() + val thrown = intercept[Exception]{ + rdd.collect() + } + assert(thrown.getMessage.contains("injected failure")) + shouldFail = false + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("empty RDD") { + val empty = new EmptyRDD[Int](sc) + assert(empty.count === 0) + assert(empty.collect().size === 0) + + val thrown = intercept[UnsupportedOperationException]{ + empty.reduce(_+_) + } + assert(thrown.getMessage.contains("empty")) + + val emptyKv = new EmptyRDD[(Int, Int)](sc) + val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) + assert(rdd.join(emptyKv).collect().size === 0) + assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) + assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.cogroup(emptyKv).collect().size === 2) + assert(rdd.union(emptyKv).collect().size === 2) + } + + test("cogrouped RDDs") { + val data = sc.parallelize(1 to 10, 10) + + val coalesced1 = data.coalesce(2) + assert(coalesced1.collect().toList === (1 to 10).toList) + assert(coalesced1.glom().collect().map(_.toList).toList === + List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) + + // Check that the narrow dependency is also specified correctly + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) + + val coalesced2 = data.coalesce(3) + assert(coalesced2.collect().toList === (1 to 10).toList) + assert(coalesced2.glom().collect().map(_.toList).toList === + List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) + + val coalesced3 = data.coalesce(10) + assert(coalesced3.collect().toList === (1 to 10).toList) + assert(coalesced3.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. + val coalesced4 = data.coalesce(20) + assert(coalesced4.collect().toList === (1 to 10).toList) + assert(coalesced4.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // we can optionally shuffle to keep the upstream parallel + val coalesced5 = data.coalesce(1, shuffle = true) + assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != + null) + } + test("cogrouped RDDs with locality") { + val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) + val coal3 = data3.coalesce(3) + val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) + assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") + + // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val coalesced1 = data.coalesce(3) + assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + + val splits = coalesced1.glom().collect().map(_.toList).toList + assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) + + assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. + val coalesced4 = data.coalesce(20) + val listOfLists = coalesced4.glom().collect().map(_.toList).toList + val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert( sortedList === (1 to 9). + map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + } + + test("cogrouped RDDs with locality, large scale (10K partitions)") { + // large scale experiment + import collection.mutable + val rnd = scala.util.Random + val partitions = 10000 + val numMachines = 50 + val machines = mutable.ListBuffer[String]() + (1 to numMachines).foreach(machines += "m"+_) + + val blocks = (1 to partitions).map(i => + { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) + + val data2 = sc.makeRDD(blocks) + val coalesced2 = data2.coalesce(numMachines*2) + + // test that you get over 90% locality in each group + val minLocality = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) + assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) + + val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs + val coalesced3 = data3.coalesce(numMachines*2) + val minLocality2 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + + (minLocality2*100.).toInt + "%") + } + + test("zipped RDDs") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val zipped = nums.zip(nums.map(_ + 1.0)) + assert(zipped.glom().map(_.toList).collect().toList === + List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) + + intercept[IllegalArgumentException] { + nums.zip(sc.parallelize(1 to 4, 1)).collect() + } + } + + test("partition pruning") { + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) + assert(prunedRdd.partitions.size === 1) + val prunedData = prunedRdd.collect() + assert(prunedData.size === 1) + assert(prunedData(0) === 10) + } + + test("mapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.mapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextDouble * t}.collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(2) === prn42_3) + assert(randoms(5) === prn43_3) + } + + test("flatMapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.flatMapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => + val random = prng.nextDouble() + Seq(random * t, random * t * 10)}. + collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(5) === prn42_3 * 10) + assert(randoms(11) === prn43_3 * 10) + } + + test("filterWith") { + import java.util.Random + val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + val sample = ints.filterWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextInt(3) == 0}. + collect() + val checkSample = { + val prng42 = new Random(42) + val prng43 = new Random(43) + Array(1, 2, 3, 4, 5, 6).filter{i => + if (i < 4) 0 == prng42.nextInt(3) + else 0 == prng43.nextInt(3)} + } + assert(sample.size === checkSample.size) + for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) + } + + test("top with predefined ordering") { + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK === nums.reverse.take(5)) + } + + test("top with custom ordering") { + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } + + test("takeOrdered with predefined ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val rdd = sc.makeRDD(nums, 2) + val sortedLowerK = rdd.takeOrdered(5) + assert(sortedLowerK.size === 5) + assert(sortedLowerK === Array(1, 2, 3, 4, 5)) + } + + test("takeOrdered with custom ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + implicit val ord = implicitly[Ordering[Int]].reverse + val rdd = sc.makeRDD(nums, 2) + val sortedTopK = rdd.takeOrdered(5) + assert(sortedTopK.size === 5) + assert(sortedTopK === Array(10, 9, 8, 7, 6)) + assert(sortedTopK === nums.sorted(ord).take(5)) + } + + test("takeSample") { + val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.toSet.size === 20) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 200, seed) + assert(sample.size === 100) // Got only 100 elements + assert(sample.toSet.size === 100) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 100, seed) + assert(sample.size === 100) // Got exactly 100 elements + // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 200, seed) + assert(sample.size === 200) // Got exactly 200 elements + // Chance of getting all distinct elements is still quite low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + } + + test("runJob on an invalid partition") { + intercept[IllegalArgumentException] { + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala new file mode 100644 index 0000000000..97cbca09bf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ +trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => + + @transient private var _sc: SparkContext = _ + + def sc: SparkContext = _sc + + override def beforeAll() { + _sc = new SparkContext("local", "test") + super.beforeAll() + } + + override def afterAll() { + if (_sc != null) { + LocalSparkContext.stop(_sc) + _sc = null + } + super.afterAll() + } +} diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala new file mode 100644 index 0000000000..e121b162ad --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.BeforeAndAfterAll + + +class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. + + override def beforeAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "true") + } + + override def afterAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "false") + } +} diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala new file mode 100644 index 0000000000..357175e89e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext._ +import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} +import org.apache.spark.util.MutablePair + + +class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + test("groupByKey without compression") { + try { + System.setProperty("spark.shuffle.compress", "false") + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) + val groups = pairs.groupByKey(4).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } finally { + System.setProperty("spark.shuffle.compress", "true") + } + } + + test("shuffle non-zero block size") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 + + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( + b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) + } + } + + test("shuffle serializer") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( + b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName) + assert(c.count === 10) + } + + test("zero sized blocks") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + .setSerializer(classOf[KryoSerializer].getName) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer should create zero-sized blocks + val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + + test("shuffle using mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) + val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) + val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) + .collect() + + data.foreach { pair => results should contain (pair) } + } + + test("sorting using mutable pairs") { + // This is not in SortingSuite because of the local cluster setup. + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) + val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) + val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) + .sortByKey().collect() + results(0) should be (p(1, 11)) + results(1) should be (p(2, 22)) + results(2) should be (p(3, 33)) + results(3) should be (p(100, 100)) + } + + test("cogroup using mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) + val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) + val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) + val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) + val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() + + assert(results(1)(0).length === 3) + assert(results(1)(0).contains(1)) + assert(results(1)(0).contains(2)) + assert(results(1)(0).contains(3)) + assert(results(1)(1).length === 2) + assert(results(1)(1).contains("11")) + assert(results(1)(1).contains("12")) + assert(results(2)(0).length === 1) + assert(results(2)(0).contains(1)) + assert(results(2)(1).length === 1) + assert(results(2)(1).contains("22")) + assert(results(3)(0).length === 0) + assert(results(3)(1).contains("3")) + } + + test("subtract mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) + val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) + val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) + val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) + val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() + results should have length (1) + // substracted rdd return results as Tuple2 + results(0) should be ((3, 33)) + } +} + +object ShuffleSuite { + + def mergeCombineException(x: Int, y: Int): Int = { + throw new SparkException("Exception for map-side combine.") + x + y + } + + class NonJavaSerializableClass(val value: Int) +} diff --git a/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala new file mode 100644 index 0000000000..214ac74898 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterAll +import org.scalatest.PrivateMethodTester + +class DummyClass1 {} + +class DummyClass2 { + val x: Int = 0 +} + +class DummyClass3 { + val x: Int = 0 + val y: Double = 0.0 +} + +class DummyClass4(val d: DummyClass3) { + val x: Int = 0 +} + +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + +class SizeEstimatorSuite + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + + var oldArch: String = _ + var oldOops: String = _ + + override def beforeAll() { + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + oldArch = System.setProperty("os.arch", "amd64") + oldOops = System.setProperty("spark.test.useCompressedOops", "true") + } + + override def afterAll() { + resetOrClear("os.arch", oldArch) + resetOrClear("spark.test.useCompressedOops", oldOops) + } + + test("simple classes") { + assert(SizeEstimator.estimate(new DummyClass1) === 16) + assert(SizeEstimator.estimate(new DummyClass2) === 16) + assert(SizeEstimator.estimate(new DummyClass3) === 24) + assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) + assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("strings") { + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + } + + test("primitive arrays") { + assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) + assert(SizeEstimator.estimate(new Array[Char](10)) === 40) + assert(SizeEstimator.estimate(new Array[Short](10)) === 40) + assert(SizeEstimator.estimate(new Array[Int](10)) === 56) + assert(SizeEstimator.estimate(new Array[Long](10)) === 96) + assert(SizeEstimator.estimate(new Array[Float](10)) === 56) + assert(SizeEstimator.estimate(new Array[Double](10)) === 96) + assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) + assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) + } + + test("object arrays") { + // Arrays containing nulls should just have one pointer per element + assert(SizeEstimator.estimate(new Array[String](10)) === 56) + assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) + + // For object arrays with non-null elements, each object should take one pointer plus + // however many bytes that class takes. (Note that Array.fill calls the code in its + // second parameter separately for each object, so we get distinct objects.) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) + assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) + + // Past size 100, our samples 100 elements, but we should still get the right size. + assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) + + // If an array contains the *same* element many times, we should only count it once. + val d1 = new DummyClass1 + assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object + assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object + + // Same thing with huge array containing the same element many times. Note that this won't + // return exactly 4032 because it can't tell that *all* the elements will equal the first + // one it samples, but it should be close to that. + + // TODO: If we sample 100 elements, this should always be 4176 ? + val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) + assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") + assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") + } + + test("32-bit arch") { + val arch = System.setProperty("os.arch", "x86") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + + resetOrClear("os.arch", arch) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("64-bit arch with no compressed oops") { + val arch = System.setProperty("os.arch", "amd64") + val oops = System.setProperty("spark.test.useCompressedOops", "false") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 56) + assert(SizeEstimator.estimate(DummyString("a")) === 64) + assert(SizeEstimator.estimate(DummyString("ab")) === 64) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) + + resetOrClear("os.arch", arch) + resetOrClear("spark.test.useCompressedOops", oops) + } + + def resetOrClear(prop: String, oldValue: String) { + if (oldValue != null) { + System.setProperty(prop, oldValue) + } else { + System.clearProperty(prop) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SortingSuite.scala b/core/src/test/scala/org/apache/spark/SortingSuite.scala new file mode 100644 index 0000000000..f4fa9511dd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SortingSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.ShouldMatchers +import SparkContext._ + +class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { + + test("sortByKey") { + val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + } + + test("large array") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey() + assert(sorted.partitions.size === 2) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 1) + assert(sorted.partitions.size === 1) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 20) + assert(sorted.partitions.size === 20) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("sort descending") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 1) + assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("more partitions than elements") { + val rand = new scala.util.Random() + val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 30) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("empty RDD") { + val pairArr = new Array[(Int, Int)](0) + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("partition balancing") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey() + assert(sorted.collect() === pairArr.sortBy(_._1)) + val partitions = sorted.collectPartitions() + logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + partitions(3).length should be > 180 + partitions(0).last should be < partitions(1).head + partitions(1).last should be < partitions(2).head + partitions(2).last should be < partitions(3).head + } + + test("partition balancing for descending sort") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + assert(sorted.collect() === pairArr.sortBy(_._1).reverse) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + partitions(3).length should be > 180 + partitions(0).last should be > partitions(1).head + partitions(1).last should be > partitions(2).head + partitions(2).last should be > partitions(3).head + } +} + diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala new file mode 100644 index 0000000000..939fe51801 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.apache.spark.SparkContext._ + +class SparkContextInfoSuite extends FunSuite with LocalSparkContext { + test("getPersistentRDDs only returns RDDs that are marked as cached") { + sc = new SparkContext("local", "test") + assert(sc.getPersistentRDDs.isEmpty === true) + + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(sc.getPersistentRDDs.isEmpty === true) + + rdd.cache() + assert(sc.getPersistentRDDs.size === 1) + assert(sc.getPersistentRDDs.values.head === rdd) + } + + test("getPersistentRDDs returns an immutable map") { + sc = new SparkContext("local", "test") + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + + val myRdds = sc.getPersistentRDDs + assert(myRdds.size === 1) + assert(myRdds.values.head === rdd1) + + val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() + + // getPersistentRDDs should have 2 RDDs, but myRdds should not change + assert(sc.getPersistentRDDs.size === 2) + assert(myRdds.size === 1) + } + + test("getRDDStorageInfo only reports on RDDs that actually persist data") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + + assert(sc.getRDDStorageInfo.size === 0) + + rdd.collect() + assert(sc.getRDDStorageInfo.size === 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala new file mode 100644 index 0000000000..69383ddfb8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import SparkContext._ + +/** + * Holds state shared across task threads in some ThreadingSuite tests. + */ +object ThreadingSuiteState { + val runningThreads = new AtomicInteger + val failed = new AtomicBoolean + + def clear() { + runningThreads.set(0) + failed.set(false) + } +} + +class ThreadingSuite extends FunSuite with LocalSparkContext { + + test("accessing SparkContext form a different thread") { + sc = new SparkContext("local", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var answer1: Int = 0 + @volatile var answer2: Int = 0 + new Thread { + override def run() { + answer1 = nums.reduce(_ + _) + answer2 = nums.first() // This will run "locally" in the current thread + sem.release() + } + }.start() + sem.acquire() + assert(answer1 === 55) + assert(answer2 === 1) + } + + test("accessing SparkContext form multiple threads") { + sc = new SparkContext("local", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var ok = true + for (i <- 0 until 10) { + new Thread { + override def run() { + val answer1 = nums.reduce(_ + _) + if (answer1 != 55) { + printf("In thread %d: answer1 was %d\n", i, answer1) + ok = false + } + val answer2 = nums.first() // This will run "locally" in the current thread + if (answer2 != 1) { + printf("In thread %d: answer2 was %d\n", i, answer2) + ok = false + } + sem.release() + } + }.start() + } + sem.acquire(10) + if (!ok) { + fail("One or more threads got the wrong answer from an RDD operation") + } + } + + test("accessing multi-threaded SparkContext form multiple threads") { + sc = new SparkContext("local[4]", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var ok = true + for (i <- 0 until 10) { + new Thread { + override def run() { + val answer1 = nums.reduce(_ + _) + if (answer1 != 55) { + printf("In thread %d: answer1 was %d\n", i, answer1) + ok = false + } + val answer2 = nums.first() // This will run "locally" in the current thread + if (answer2 != 1) { + printf("In thread %d: answer2 was %d\n", i, answer2) + ok = false + } + sem.release() + } + }.start() + } + sem.acquire(10) + if (!ok) { + fail("One or more threads got the wrong answer from an RDD operation") + } + } + + test("parallel job execution") { + // This test launches two jobs with two threads each on a 4-core local cluster. Each thread + // waits until there are 4 threads running at once, to test that both jobs have been launched. + sc = new SparkContext("local[4]", "test") + val nums = sc.parallelize(1 to 2, 2) + val sem = new Semaphore(0) + ThreadingSuiteState.clear() + for (i <- 0 until 2) { + new Thread { + override def run() { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + println("Waited 1 second without seeing runningThreads = 4 (it was " + + running.get() + "); failing test") + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + sem.release() + } + }.start() + } + sem.acquire(2) + if (ThreadingSuiteState.failed.get()) { + fail("One or more threads didn't see runningThreads = 4") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala new file mode 100644 index 0000000000..46a2da1724 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import org.apache.spark.SparkContext._ + +class UnpersistSuite extends FunSuite with LocalSparkContext { + test("unpersist RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + rdd.count + assert(sc.persistentRdds.isEmpty === false) + rdd.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + assert(sc.getRDDStorageInfo.isEmpty === true) + } +} diff --git a/core/src/test/scala/org/apache/spark/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/UtilsSuite.scala new file mode 100644 index 0000000000..3a908720a8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/UtilsSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import com.google.common.base.Charsets +import com.google.common.io.Files +import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} +import org.scalatest.FunSuite +import org.apache.commons.io.FileUtils +import scala.util.Random + +class UtilsSuite extends FunSuite { + + test("bytesToString") { + assert(Utils.bytesToString(10) === "10.0 B") + assert(Utils.bytesToString(1500) === "1500.0 B") + assert(Utils.bytesToString(2000000) === "1953.1 KB") + assert(Utils.bytesToString(2097152) === "2.0 MB") + assert(Utils.bytesToString(2306867) === "2.2 MB") + assert(Utils.bytesToString(5368709120L) === "5.0 GB") + assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") + } + + test("copyStream") { + //input array initialization + val bytes = Array.ofDim[Byte](9000) + Random.nextBytes(bytes) + + val os = new ByteArrayOutputStream() + Utils.copyStream(new ByteArrayInputStream(bytes), os) + + assert(os.toByteArray.toList.equals(bytes.toList)) + } + + test("memoryStringToMb") { + assert(Utils.memoryStringToMb("1") === 0) + assert(Utils.memoryStringToMb("1048575") === 0) + assert(Utils.memoryStringToMb("3145728") === 3) + + assert(Utils.memoryStringToMb("1024k") === 1) + assert(Utils.memoryStringToMb("5000k") === 4) + assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) + + assert(Utils.memoryStringToMb("1024m") === 1024) + assert(Utils.memoryStringToMb("5000m") === 5000) + assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) + + assert(Utils.memoryStringToMb("2g") === 2048) + assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) + + assert(Utils.memoryStringToMb("2t") === 2097152) + assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) + } + + test("splitCommandString") { + assert(Utils.splitCommandString("") === Seq()) + assert(Utils.splitCommandString("a") === Seq("a")) + assert(Utils.splitCommandString("aaa") === Seq("aaa")) + assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) + assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) + assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) + assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("'b c'") === Seq("b c")) + assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) + assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) + assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) + assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) + assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) + assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) + assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) + assert(Utils.splitCommandString("'a'b") === Seq("ab")) + assert(Utils.splitCommandString("'a''b'") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) + assert(Utils.splitCommandString("''") === Seq("")) + assert(Utils.splitCommandString("\"\"") === Seq("")) + } + + test("string formatting of time durations") { + val second = 1000 + val minute = second * 60 + val hour = minute * 60 + def str = Utils.msDurationToString(_) + + assert(str(123) === "123 ms") + assert(str(second) === "1.0 s") + assert(str(second + 462) === "1.5 s") + assert(str(hour) === "1.00 h") + assert(str(minute) === "1.0 m") + assert(str(minute + 4 * second + 34) === "1.1 m") + assert(str(10 * hour + minute + 4 * second) === "10.02 h") + assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") + } + + test("reading offset bytes of a file") { + val tmpDir2 = Files.createTempDir() + val f1Path = tmpDir2 + "/f1" + val f1 = new FileOutputStream(f1Path) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) + f1.close() + + // Read first few bytes + assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") + + // Read some middle bytes + assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") + + // Read last few bytes + assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") + + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") + + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") + + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + + FileUtils.deleteDirectory(tmpDir2) + } +} + diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala new file mode 100644 index 0000000000..618b9c113b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import SparkContext._ + + +object ZippedPartitionsSuite { + def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { + Iterator(i.toArray.size, s.toArray.size, d.toArray.size) + } +} + +class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { + test("print sizes") { + val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) + val data3 = sc.makeRDD(Array(1.0, 2.0), 2) + + val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData) + + val obtainedSizes = zippedRDD.collect() + val expectedSizes = Array(2, 3, 1, 2, 3, 1) + assert(obtainedSizes.size == 6) + assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala new file mode 100644 index 0000000000..fd6f69041a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.scalatest.FunSuite + + +class CompressionCodecSuite extends FunSuite { + + def testCodec(codec: CompressionCodec) { + // Write 1000 integers to the output stream, compressed. + val outputStream = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(outputStream) + for (i <- 1 until 1000) { + out.write(i % 256) + } + out.close() + + // Read the 1000 integers back. + val inputStream = new ByteArrayInputStream(outputStream.toByteArray) + val in = codec.compressedInputStream(inputStream) + for (i <- 1 until 1000) { + assert(in.read() === i % 256) + } + in.close() + } + + test("default compression codec") { + val codec = CompressionCodec.createCodec() + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } + + test("lzf compression codec") { + val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + + test("snappy compression codec") { + val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala new file mode 100644 index 0000000000..58c94a162d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics + +import org.scalatest.{BeforeAndAfter, FunSuite} + +class MetricsConfigSuite extends FunSuite with BeforeAndAfter { + var filePath: String = _ + + before { + filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile() + } + + test("MetricsConfig with default properties") { + val conf = new MetricsConfig(Option("dummy-file")) + conf.initialize() + + assert(conf.properties.size() === 5) + assert(conf.properties.getProperty("test-for-dummy") === null) + + val property = conf.getInstance("random") + assert(property.size() === 3) + assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.uri") === "/metrics/json") + assert(property.getProperty("sink.servlet.sample") === "false") + } + + test("MetricsConfig with properties set") { + val conf = new MetricsConfig(Option(filePath)) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 6) + assert(masterProp.getProperty("sink.console.period") === "20") + assert(masterProp.getProperty("sink.console.unit") === "minutes") + assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.uri") === "/metrics/master/json") + assert(masterProp.getProperty("sink.servlet.sample") === "false") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 6) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.uri") === "/metrics/json") + assert(workerProp.getProperty("sink.servlet.sample") === "false") + } + + test("MetricsConfig with subProperties") { + val conf = new MetricsConfig(Option(filePath)) + conf.initialize() + + val propCategories = conf.propertyCategories + assert(propCategories.size === 3) + + val masterProp = conf.getInstance("master") + val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX) + assert(sourceProps.size === 1) + assert(sourceProps("jvm").getProperty("class") === "org.apache.spark.metrics.source.JvmSource") + + val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX) + assert(sinkProps.size === 2) + assert(sinkProps.contains("console")) + assert(sinkProps.contains("servlet")) + + val consoleProps = sinkProps("console") + assert(consoleProps.size() === 2) + + val servletProps = sinkProps("servlet") + assert(servletProps.size() === 3) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala new file mode 100644 index 0000000000..7181333adf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.deploy.master.MasterSource + +class MetricsSystemSuite extends FunSuite with BeforeAndAfter { + var filePath: String = _ + + before { + filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() + System.setProperty("spark.metrics.conf", filePath) + } + + test("MetricsSystem with default config") { + val metricsSystem = MetricsSystem.createMetricsSystem("default") + val sources = metricsSystem.sources + val sinks = metricsSystem.sinks + + assert(sources.length === 0) + assert(sinks.length === 0) + assert(!metricsSystem.getServletHandlers.isEmpty) + } + + test("MetricsSystem with sources add") { + val metricsSystem = MetricsSystem.createMetricsSystem("test") + val sources = metricsSystem.sources + val sinks = metricsSystem.sinks + + assert(sources.length === 0) + assert(sinks.length === 1) + assert(!metricsSystem.getServletHandlers.isEmpty) + + val source = new MasterSource(null) + metricsSystem.registerSource(source) + assert(sources.length === 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..3d39a31252 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala new file mode 100644 index 0000000000..a80afdee7e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +class ParallelCollectionSplitSuite extends FunSuite with Checkers { + test("one element per slice") { + val data = Array(1, 2, 3) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1") + assert(slices(1).mkString(",") === "2") + assert(slices(2).mkString(",") === "3") + } + + test("one slice") { + val data = Array(1, 2, 3) + val slices = ParallelCollectionRDD.slice(data, 1) + assert(slices.size === 1) + assert(slices(0).mkString(",") === "1,2,3") + } + + test("equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9") + } + + test("non-equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9,10") + } + + test("splitting exclusive range") { + val data = 0 until 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 65).mkString(",")) + assert(slices(2).mkString(",") === (66 to 99).mkString(",")) + } + + test("splitting inclusive range") { + val data = 0 to 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 66).mkString(",")) + assert(slices(2).mkString(",") === (67 to 100).mkString(",")) + } + + test("empty data") { + val data = new Array[Int](0) + val slices = ParallelCollectionRDD.slice(data, 5) + assert(slices.size === 5) + for (slice <- slices) assert(slice.size === 0) + } + + test("zero slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } + } + + test("negative number of slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } + } + + test("exclusive ranges sliced into ranges") { + val data = 1 until 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[Range])) + } + + test("inclusive ranges sliced into ranges") { + val data = 1 to 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[Range])) + } + + test("large ranges don't overflow") { + val N = 100 * 1000 * 1000 + val data = 0 until N + val slices = ParallelCollectionRDD.slice(data, 40) + assert(slices.size === 40) + for (i <- 0 until 40) { + assert(slices(i).isInstanceOf[Range]) + val range = slices(i).asInstanceOf[Range] + assert(range.start === i * (N / 40), "slice " + i + " start") + assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.step === 1, "slice " + i + " step") + } + } + + test("random array tests") { + val gen = for { + d <- arbitrary[List[Int]] + n <- Gen.choose(1, 100) + } yield (d, n) + val prop = forAll(gen) { + (tuple: (List[Int], Int)) => + val d = tuple._1 + val n = tuple._2 + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random exclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a until b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random inclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a to b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("exclusive ranges of longs") { + val data = 1L until 100L + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of longs") { + val data = 1L to 100L + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("exclusive ranges of doubles") { + val data = 1.0 until 100.0 by 1.0 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of doubles") { + val data = 1.0 to 100.0 by 1.0 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala new file mode 100644 index 0000000000..94df282b28 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.{Map, HashMap} + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark.LocalSparkContext +import org.apache.spark.MapOutputTracker +import org.apache.spark.RDD +import org.apache.spark.SparkContext +import org.apache.spark.Partition +import org.apache.spark.TaskContext +import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} +import org.apache.spark.{FetchFailed, Success, TaskEndReason} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} + +import org.apache.spark.scheduler.cluster.Pool +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** + * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler + * rather than spawning an event loop thread as happens in the real code. They use EasyMock + * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are + * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead + * host notifications are sent). In addition, tests may check for side effects on a non-mocked + * MapOutputTracker instance. + * + * Tests primarily consist of running DAGScheduler#processEvent and + * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet) + * and capturing the resulting TaskSets from the mock TaskScheduler. + */ +class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + /** Set of TaskSets the DAGScheduler has requested executed. */ + val taskSets = scala.collection.mutable.Buffer[TaskSet]() + val taskScheduler = new TaskScheduler() { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + // normally done by TaskSetManager + taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSets += taskSet + } + override def setListener(listener: TaskSchedulerListener) = {} + override def defaultParallelism() = 2 + } + + var mapOutputTracker: MapOutputTracker = null + var scheduler: DAGScheduler = null + + /** + * Set of cache locations to return from our mock BlockManagerMaster. + * Keys are (rdd ID, partition ID). Anything not present will return an empty + * list of cache locations silently. + */ + val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] + // stub out BlockManagerMaster.getLocations to use our cacheLocations + val blockManagerMaster = new BlockManagerMaster(null) { + override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + blockIds.map { name => + val pieces = name.split("_") + if (pieces(0) == "rdd") { + val key = pieces(1).toInt -> pieces(2).toInt + cacheLocations.getOrElse(key, Seq()) + } else { + Seq() + } + }.toSeq + } + override def removeExecutor(execId: String) { + // don't need to propagate to the driver, which we don't have + } + } + + /** The list of results that DAGScheduler has collected. */ + val results = new HashMap[Int, Any]() + var failure: Exception = _ + val listener = new JobListener() { + override def taskSucceeded(index: Int, result: Any) = results.put(index, result) + override def jobFailed(exception: Exception) = { failure = exception } + } + + before { + sc = new SparkContext("local", "DAGSchedulerSuite") + taskSets.clear() + cacheLocations.clear() + results.clear() + mapOutputTracker = new MapOutputTracker() + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { + override def runLocally(job: ActiveJob) { + // don't bother with the thread while unit testing + runLocallyWithinThread(job) + } + } + } + + after { + scheduler.stop() + } + + /** + * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. + */ + type MyRDD = RDD[(Int, Int)] + + /** + * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + */ + private def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + override def getPreferredLocations(split: Partition): Seq[String] = + if (locations.isDefinedAt(split.index)) + locations(split.index) + else + Nil + override def toString: String = "DAGSchedulerSuiteRDD " + id + } + } + + /** + * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + * the scheduler not to exit. + * + * After processing the event, submit waiting stages as is done on most iterations of the + * DAGScheduler event loop. + */ + private def runEvent(event: DAGSchedulerEvent) { + assert(!scheduler.processEvent(event)) + scheduler.submitWaitingStages() + } + + /** + * When we submit dummy Jobs, this is the compute function we supply. Except in a local test + * below, we do not expect this function to ever be executed; instead, we will return results + * directly through CompletionEvents. + */ + private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => + it.next.asInstanceOf[Tuple2[_, _]]._1 + + /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ + private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + assert(taskSet.tasks.size >= results.size) + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null)) + } + } + } + + /** Sends the rdd to the scheduler for scheduling. */ + private def submit( + rdd: RDD[_], + partitions: Array[Int], + func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, + allowLocal: Boolean = false, + listener: JobListener = listener) { + runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) + } + + /** Sends TaskSetFailed to the scheduler. */ + private def failed(taskSet: TaskSet, message: String) { + runEvent(TaskSetFailed(taskSet, message)) + } + + test("zero split job") { + val rdd = makeRdd(0, Nil) + var numResults = 0 + val fakeListener = new JobListener() { + override def taskSucceeded(partition: Int, value: Any) = numResults += 1 + override def jobFailed(exception: Exception) = throw exception + } + submit(rdd, Array(), listener = fakeListener) + assert(numResults === 0) + } + + test("run trivial job") { + val rdd = makeRdd(1, Nil) + submit(rdd, Array(0)) + complete(taskSets(0), List((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("local job") { + val rdd = new MyRDD(sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Array(42 -> 0).iterator + override def getPartitions = Array( new Partition { override def index = 0 } ) + override def getPreferredLocations(split: Partition) = Nil + override def toString = "DAGSchedulerSuite Local RDD" + } + runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) + assert(results === Map(0 -> 42)) + } + + test("run trivial job w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + submit(finalRdd, Array(0)) + complete(taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("cache location preferences w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + submit(finalRdd, Array(0)) + val taskSet = taskSets(0) + assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) + complete(taskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("trivial job failure") { + submit(makeRdd(1, Nil), Array(0)) + failed(taskSets(0), "some failure") + assert(failure.getMessage === "Job failed: some failure") + } + + test("run trivial shuffle") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + complete(taskSets(1), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("run trivial shuffle with fetch failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // the 2nd ResultTask failed + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) + // this will get called + // blockManagerMaster.removeExecutor("exec-hostA") + // ask the scheduler to try it again + scheduler.resubmitFailedStages() + // have the 2nd attempt pass + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + // we can see both result blocks now + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + } + + test("ignore late map task completions") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + // pretend we were told hostA went away + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + val noAccum = Map[Long, Any]() + val taskSet = taskSets(0) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + // should work because it's a non-failed host + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null)) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + // should work because it's a new epoch + taskSet.tasks(1).epoch = newEpoch + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + complete(taskSets(1), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + } + + test("run trivial shuffle with out-of-band failure and retry") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + // blockManagerMaster.removeExecutor("exec-hostA") + // pretend we were told hostA went away + runEvent(ExecutorLost("exec-hostA")) + // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks + // rather than marking it is as failed and waiting. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("recursive shuffle failures") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + // have the first stage complete normally + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // have the second stage complete normally + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)))) + // fail the third stage because hostA went down + complete(taskSets(2), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + // TODO assert this: + // blockManagerMaster.removeExecutor("exec-hostA") + // have DAGScheduler try again + scheduler.resubmitFailedStages() + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) + complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(5), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("cached post-shuffle") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + // complete stage 2 + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // complete stage 1 + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // pretend stage 0 failed because hostA went down + complete(taskSets(2), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + // TODO assert this: + // blockManagerMaster.removeExecutor("exec-hostA") + // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. + scheduler.resubmitFailedStages() + assertLocations(taskSets(3), Seq(Seq("hostD"))) + // allow hostD to recover + complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + /** + * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. + * Note that this checks only the host and not the executor ID. + */ + private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { + assert(hosts.size === taskSet.tasks.size) + for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { + assert(taskLocs.map(_.host) === expectedLocs) + } + } + + private def makeMapStatus(host: String, reduces: Int): MapStatus = + new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + + private def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345, 0) + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..f5b3e97222 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import org.apache.spark._ +import org.apache.spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") { + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None) + + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null)) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala new file mode 100644 index 0000000000..aac7c207cb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.apache.spark.{SparkContext, LocalSparkContext} +import scala.collection.mutable +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.SparkContext._ + +/** + * + */ + +class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("local metrics") { + sc = new SparkContext("local[4]", "test") + val listener = new SaveStageInfo + sc.addSparkListener(listener) + sc.addSparkListener(new StatsReportListener) + //just to make sure some of the tasks take a noticeable amount of time + val w = {i:Int => + if (i == 0) + Thread.sleep(100) + i + } + + val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} + d.count + listener.stageInfos.size should be (1) + + val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1") + + val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2") + + val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)} + d4.setName("A Cogroup") + + d4.collectAsMap + + listener.stageInfos.size should be (4) + listener.stageInfos.foreach {stageInfo => + //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms + checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") + if (stageInfo.stage.rdd.name == d4.name) { + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") + } + + stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => + taskMetrics.resultSize should be > (0l) + if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { + taskMetrics.shuffleWriteMetrics should be ('defined) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + } + if (stageInfo.stage.rdd.name == d4.name) { + taskMetrics.shuffleReadMetrics should be ('defined) + val sm = taskMetrics.shuffleReadMetrics.get + sm.totalBlocksFetched should be > (0) + sm.localBlocksFetched should be > (0) + sm.remoteBlocksFetched should be (0) + sm.remoteBytesRead should be (0l) + sm.remoteFetchTime should be (0l) + } + } + } + } + + def checkNonZeroAvg(m: Traversable[Long], msg: String) { + assert(m.sum / m.size.toDouble > 0.0, msg) + } + + def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = { + val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name} + !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty + } + + class SaveStageInfo extends SparkListener { + val stageInfos = mutable.Buffer[StageInfo]() + override def onStageCompleted(stage: StageCompleted) { + stageInfos += stage.stageInfo + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..0347cc02d7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.apache.spark.TaskContext +import org.apache.spark.RDD +import org.apache.spark.SparkContext +import org.apache.spark.Partition +import org.apache.spark.LocalSparkContext + +class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubPartition(val index: Int) extends Partition +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala new file mode 100644 index 0000000000..92ad9f09b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer + +import java.util.Properties + +class FakeTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int, + clusterScheduler: ClusterScheduler, + taskSet: TaskSet) + extends ClusterTaskSetManager(clusterScheduler, taskSet) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksFinished = 0 + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable) { + } + + override def removeSchedulable(schedulable: Schedulable) { + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, host: String): Unit = { + } + + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + if (tasksFinished + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } + return None + } + + override def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksFinished +=1 + if (tasksFinished == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) + } + + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedTaskSetQueue() + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { + taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { + case Some(task) => + return taskSet.stageId + case None => {} + } + } + -1 + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + assert(resourceOffer(rootPool) === expectedTaskSetId) + } + + test("FIFO Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + resourceOffer(rootPool) + checkTaskSetId(rootPool, 1) + resourceOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 2) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + properties1.setProperty("spark.scheduler.cluster.fair.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.cluster.fair.pool","2") + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala new file mode 100644 index 0000000000..a4f63baf3d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.executor.TaskMetrics +import java.nio.ByteBuffer +import org.apache.spark.util.FakeClock + +/** + * A mock ClusterScheduler implementation that just remembers information about tasks started and + * feedback received from the TaskSetManagers. Note that it's important to initialize this with + * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost + * to work, and these are required for locality in ClusterTaskSetManager. + */ +class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) + extends ClusterScheduler(sc) +{ + val startedTasks = new ArrayBuffer[Long] + val endedTasks = new mutable.HashMap[Long, TaskEndReason] + val finishedManagers = new ArrayBuffer[TaskSetManager] + + val executors = new mutable.HashMap[String, String] ++ liveExecutors + + listener = new TaskSchedulerListener { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { + startedTasks += taskInfo.index + } + + def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) + { + endedTasks(taskInfo.index) = reason + } + + def executorGained(execId: String, host: String) {} + + def executorLost(execId: String) {} + + def taskSetFailed(taskSet: TaskSet, reason: String) {} + } + + def removeExecutor(execId: String): Unit = executors -= execId + + override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager + + override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) + + override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) +} + +class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { + import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + test("TaskSet with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // Offer a host with no CPUs + assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) + + // Offer a host with process-local as the constraint; this should work because the TaskSet + // above won't have any locality preferences + val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + assert(sched.startedTasks.contains(0)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) + + // Tell it the task has finished + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + assert(sched.endedTasks(0) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("multiple offers with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(3) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // First three offers should all find tasks + for (i <- 0 until 3) { + val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + } + assert(sched.startedTasks.toSet === Set(0, 1, 2)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Finish the first two tasks + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1)) + assert(sched.endedTasks(0) === Success) + assert(sched.endedTasks(1) === Success) + assert(!sched.finishedManagers.contains(manager)) + + // Finish the last task + manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2)) + assert(sched.endedTasks(2) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("basic delay scheduling") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(4, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2")), + Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), + Seq() // Last task has no locality prefs + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1, exec1 again: the last task, which has no prefs, should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) + + // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) + + // Offer host1, exec1 again, at ANY level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at ANY level: task 1 should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + } + + test("delay scheduling with fallback") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, + ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) + val taskSet = createTaskSet(5, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")), + Seq(TaskLocation("host2")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1 again: second task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1 again: third task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // Offer host2: fifth task (also on host2) should get chosen + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) + + // Now that we've launched a local task, we should no longer launch the task for host3 + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // After another delay, we can go ahead and launch that task non-locally + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) + } + + test("delay scheduling with failed hosts") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(3, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: third task should be chosen immediately because host3 is not up + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // After this, nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + // Now mark host2 as dead + sched.removeExecutor("exec2") + manager.executorLost("exec2", "host2") + + // Task 1 should immediately be launched on host1 because its original host is gone + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Now that all tasks have launched, nothing new should be launched anywhere else + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + } + + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. + */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } + + def createTaskResult(id: Int): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics))) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala new file mode 100644 index 0000000000..2f12aaed18 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.scheduler.{TaskLocation, Task} + +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { + override def run(attemptId: Long): Int = 0 + + override def preferredLocations: Seq[TaskLocation] = prefLocs +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala new file mode 100644 index 0000000000..111340a65c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.local + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ConcurrentMap, HashMap} +import java.util.concurrent.Semaphore +import java.util.concurrent.CountDownLatch +import java.util.Properties + +class Lock() { + var finished = false + def jobWait() = { + synchronized { + while(!finished) { + this.wait() + } + } + } + + def jobFinished() = { + synchronized { + finished = true + this.notifyAll() + } + } +} + +object TaskThreadInfo { + val threadToLock = HashMap[Int, Lock]() + val threadToRunning = HashMap[Int, Boolean]() + val threadToStarted = HashMap[Int, CountDownLatch]() +} + +/* + * 1. each thread contains one job. + * 2. each job contains one stage. + * 3. each stage only contains one task. + * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure + * it will get cpu core resource, and will wait to finished after user manually + * release "Lock" and then cluster will contain another free cpu cores. + * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, + * thus it will be scheduled later when cluster has free cpu cores. + */ +class LocalSchedulerSuite extends FunSuite with LocalSparkContext { + + def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { + + TaskThreadInfo.threadToRunning(threadIndex) = false + val nums = sc.parallelize(threadIndex to threadIndex, 1) + TaskThreadInfo.threadToLock(threadIndex) = new Lock() + TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) + new Thread { + if (poolName != null) { + sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToStarted(number).countDown() + TaskThreadInfo.threadToLock(number).jobWait() + TaskThreadInfo.threadToRunning(number) = false + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + } + }.start() + } + + test("Local FIFO scheduler end-to-end test") { + System.setProperty("spark.cluster.schedulingmode", "FIFO") + sc = new SparkContext("local[4]", "test") + val sem = new Semaphore(0) + + createThread(1,null,sc,sem) + TaskThreadInfo.threadToStarted(1).await() + createThread(2,null,sc,sem) + TaskThreadInfo.threadToStarted(2).await() + createThread(3,null,sc,sem) + TaskThreadInfo.threadToStarted(3).await() + createThread(4,null,sc,sem) + TaskThreadInfo.threadToStarted(4).await() + // thread 5 and 6 (stage pending)must meet following two points + // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager + // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() + // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 + // So I just use "sleep" 1s here for each thread. + // TODO: any better solution? + createThread(5,null,sc,sem) + Thread.sleep(1000) + createThread(6,null,sc,sem) + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === true) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === false) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(1).jobFinished() + TaskThreadInfo.threadToStarted(5).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(3).jobFinished() + TaskThreadInfo.threadToStarted(6).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === false) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === true) + + TaskThreadInfo.threadToLock(2).jobFinished() + TaskThreadInfo.threadToLock(4).jobFinished() + TaskThreadInfo.threadToLock(5).jobFinished() + TaskThreadInfo.threadToLock(6).jobFinished() + sem.acquire(6) + } + + test("Local fair scheduler end-to-end test") { + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + System.setProperty("spark.cluster.schedulingmode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + + createThread(10,"1",sc,sem) + TaskThreadInfo.threadToStarted(10).await() + createThread(20,"2",sc,sem) + TaskThreadInfo.threadToStarted(20).await() + createThread(30,"3",sc,sem) + TaskThreadInfo.threadToStarted(30).await() + + assert(TaskThreadInfo.threadToRunning(10) === true) + assert(TaskThreadInfo.threadToRunning(20) === true) + assert(TaskThreadInfo.threadToRunning(30) === true) + + createThread(11,"1",sc,sem) + TaskThreadInfo.threadToStarted(11).await() + createThread(21,"2",sc,sem) + TaskThreadInfo.threadToStarted(21).await() + createThread(31,"3",sc,sem) + TaskThreadInfo.threadToStarted(31).await() + + assert(TaskThreadInfo.threadToRunning(11) === true) + assert(TaskThreadInfo.threadToRunning(21) === true) + assert(TaskThreadInfo.threadToRunning(31) === true) + + createThread(12,"1",sc,sem) + TaskThreadInfo.threadToStarted(12).await() + createThread(22,"2",sc,sem) + TaskThreadInfo.threadToStarted(22).await() + createThread(32,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(12) === true) + assert(TaskThreadInfo.threadToRunning(22) === true) + assert(TaskThreadInfo.threadToRunning(32) === false) + + TaskThreadInfo.threadToLock(10).jobFinished() + TaskThreadInfo.threadToStarted(32).await() + + assert(TaskThreadInfo.threadToRunning(32) === true) + + //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager + // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. + //2. priority of 23 and 33 will be meaningless as using fair scheduler here. + createThread(23,"2",sc,sem) + createThread(33,"3",sc,sem) + Thread.sleep(1000) + + TaskThreadInfo.threadToLock(11).jobFinished() + TaskThreadInfo.threadToStarted(23).await() + + assert(TaskThreadInfo.threadToRunning(23) === true) + assert(TaskThreadInfo.threadToRunning(33) === false) + + TaskThreadInfo.threadToLock(12).jobFinished() + TaskThreadInfo.threadToStarted(33).await() + + assert(TaskThreadInfo.threadToRunning(33) === true) + + TaskThreadInfo.threadToLock(20).jobFinished() + TaskThreadInfo.threadToLock(21).jobFinished() + TaskThreadInfo.threadToLock(22).jobFinished() + TaskThreadInfo.threadToLock(23).jobFinished() + TaskThreadInfo.threadToLock(30).jobFinished() + TaskThreadInfo.threadToLock(31).jobFinished() + TaskThreadInfo.threadToLock(32).jobFinished() + TaskThreadInfo.threadToLock(33).jobFinished() + + sem.acquire(11) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala new file mode 100644 index 0000000000..88ba10f2f2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import akka.actor._ + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.JavaSerializer +import org.apache.spark.KryoSerializer +import org.apache.spark.SizeEstimator +import org.apache.spark.Utils +import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.ByteBufferInputStream + + +class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { + var store: BlockManager = null + var store2: BlockManager = null + var actorSystem: ActorSystem = null + var master: BlockManagerMaster = null + var oldArch: String = null + var oldOops: String = null + var oldHeartBeat: String = null + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") + val serializer = new KryoSerializer + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) + this.actorSystem = actorSystem + System.setProperty("spark.driver.port", boundPort.toString) + System.setProperty("spark.hostPort", "localhost:" + boundPort) + + master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + oldArch = System.setProperty("os.arch", "amd64") + oldOops = System.setProperty("spark.test.useCompressedOops", "true") + oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + // Set some value ... + System.setProperty("spark.hostPort", Utils.localHostName() + ":" + 1111) + } + + after { + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + + if (store != null) { + store.stop() + store = null + } + if (store2 != null) { + store2.stop() + store2 = null + } + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + master = null + + if (oldArch != null) { + System.setProperty("os.arch", oldArch) + } else { + System.clearProperty("os.arch") + } + + if (oldOops != null) { + System.setProperty("spark.test.useCompressedOops", oldOops) + } else { + System.clearProperty("spark.test.useCompressedOops") + } + } + + test("StorageLevel object caching") { + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") + val bytes1 = Utils.serialize(level1) + val level1_ = Utils.deserialize[StorageLevel](bytes1) + val bytes2 = Utils.serialize(level2) + val level2_ = Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") + } + + test("BlockManagerId object caching") { + val id1 = BlockManagerId("e1", "XXX", 1, 0) + val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") + val bytes1 = Utils.serialize(id1) + val id1_ = Utils.deserialize[BlockManagerId](bytes1) + val bytes2 = Utils.serialize(id2) + val id2_ = Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") + } + + test("master + 1 manager interaction") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // Checking whether blocks are in memory + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + + // Checking whether master knows about the blocks or not + assert(master.getLocations("a1").size > 0, "master was not told about a1") + assert(master.getLocations("a2").size > 0, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") + + // Drop a1 and a2 from memory; this should be reported back to the master + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + assert(store.getSingle("a1") === None, "a1 not removed from store") + assert(store.getSingle("a2") === None, "a2 not removed from store") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") + } + + test("removing block") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // Checking whether blocks are in memory and memory size + val memStatus = master.getMemoryStatus.head._2 + assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") + assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") + assert(store.getSingle("a1-to-remove") != None, "a1 was not in store") + assert(store.getSingle("a2-to-remove") != None, "a2 was not in store") + assert(store.getSingle("a3-to-remove") != None, "a3 was not in store") + + // Checking whether master knows about the blocks or not + assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") + assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2") + assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3") + + // Remove a1 and a2 and a3. Should be no-op for a3. + master.removeBlock("a1-to-remove") + master.removeBlock("a2-to-remove") + master.removeBlock("a3-to-remove") + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a1-to-remove") should be (None) + master.getLocations("a1-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a2-to-remove") should be (None) + master.getLocations("a2-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a3-to-remove") should not be (None) + master.getLocations("a3-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + val memStatus = master.getMemoryStatus.head._2 + memStatus._1 should equal (2000L) + memStatus._2 should equal (2000L) + } + } + + test("removing rdd") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + // Putting a1, a2 and a3 in memory. + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = false) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("nonrddblock") should not be (None) + master.getLocations("nonrddblock") should have size (1) + } + + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = true) + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + + test("reregistration on heart beat") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(master.getLocations("a1").size > 0, "master was not told about a1") + + master.removeExecutor(store.blockManagerId.executorId) + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") + + store invokePrivate heartBeat() + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + } + + test("reregistration on block update") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + assert(master.getLocations("a1").size > 0, "master was not told about a1") + + master.removeExecutor(store.blockManagerId.executorId) + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") + + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.waitForAsyncReregister() + + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + assert(master.getLocations("a2").size > 0, "master was not told about a2") + } + + test("reregistration doesn't dead lock") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = List(new Array[Byte](400)) + + // try many times to trigger any deadlocks + for (i <- 1 to 100) { + master.removeExecutor(store.blockManagerId.executorId) + val t1 = new Thread { + override def run() { + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + } + val t2 = new Thread { + override def run() { + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + } + } + val t3 = new Thread { + override def run() { + store invokePrivate heartBeat() + } + } + + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() + + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + store.waitForAsyncReregister() + } + } + + test("in-memory LRU storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a1") === None, "a1 was in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + // At this point a2 was gotten last, so LRU will getSingle rid of a3 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") === None, "a3 was in store") + } + + test("in-memory LRU storage with serialization") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a1") === None, "a1 was in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + // At this point a2 was gotten last, so LRU will getSingle rid of a3 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") === None, "a3 was in store") + } + + test("in-memory LRU for partitions of same RDD") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) + // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 + // from the same RDD + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") + // Check that rdd_0_3 doesn't replace them even after further accesses + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + } + + test("in-memory LRU for partitions of multiple RDDs") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // At this point rdd_1_1 should've replaced rdd_0_1 + assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") + assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") + assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + // Do a get() on rdd_0_2 so that it is the most recently used item + assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + // Put in more partitions from RDD 0; they should replace rdd_1_1 + store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped + // when we try to add rdd_0_4. + assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") + assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") + assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") + assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") + } + + test("on-disk storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + store.putSingle("a2", a2, StorageLevel.DISK_ONLY) + store.putSingle("a3", a3, StorageLevel.DISK_ONLY) + assert(store.getSingle("a2") != None, "a2 was in store") + assert(store.getSingle("a3") != None, "a3 was in store") + assert(store.getSingle("a1") != None, "a1 was in store") + } + + test("disk and memory storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") + } + + test("disk and memory storage with getLocalBytes") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) + assert(store.getLocalBytes("a2") != None, "a2 was not in store") + assert(store.getLocalBytes("a3") != None, "a3 was not in store") + assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") + assert(store.getLocalBytes("a1") != None, "a1 was not in store") + assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") + } + + test("disk and memory storage with serialization") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") + } + + test("disk and memory storage with serialization and getLocalBytes") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) + assert(store.getLocalBytes("a2") != None, "a2 was not in store") + assert(store.getLocalBytes("a3") != None, "a3 was not in store") + assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") + assert(store.getLocalBytes("a1") != None, "a1 was not in store") + assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") + } + + test("LRU with mixed storage levels") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + val a4 = new Array[Byte](400) + // First store a1 and a2, both in memory, and a3, on disk only + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a3", a3, StorageLevel.DISK_ONLY) + // At this point LRU should not kick in because a3 is only on disk + assert(store.getSingle("a1") != None, "a2 was not in store") + assert(store.getSingle("a2") != None, "a3 was not in store") + assert(store.getSingle("a3") != None, "a1 was not in store") + assert(store.getSingle("a1") != None, "a2 was not in store") + assert(store.getSingle("a2") != None, "a3 was not in store") + assert(store.getSingle("a3") != None, "a1 was not in store") + // Now let's add in a4, which uses both disk and memory; a1 should drop out + store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) + assert(store.getSingle("a1") == None, "a1 was in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a4") != None, "a4 was not in store") + } + + test("in-memory LRU with streams") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val list1 = List(new Array[Byte](200), new Array[Byte](200)) + val list2 = List(new Array[Byte](200), new Array[Byte](200)) + val list3 = List(new Array[Byte](200), new Array[Byte](200)) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.get("list2") != None, "list2 was not in store") + assert(store.get("list2").get.size == 2) + assert(store.get("list3") != None, "list3 was not in store") + assert(store.get("list3").get.size == 2) + assert(store.get("list1") === None, "list1 was in store") + assert(store.get("list2") != None, "list2 was not in store") + assert(store.get("list2").get.size == 2) + // At this point list2 was gotten last, so LRU will getSingle rid of list3 + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.get("list1") != None, "list1 was not in store") + assert(store.get("list1").get.size == 2) + assert(store.get("list2") != None, "list2 was not in store") + assert(store.get("list2").get.size == 2) + assert(store.get("list3") === None, "list1 was in store") + } + + test("LRU with mixed storage levels and streams") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val list1 = List(new Array[Byte](200), new Array[Byte](200)) + val list2 = List(new Array[Byte](200), new Array[Byte](200)) + val list3 = List(new Array[Byte](200), new Array[Byte](200)) + val list4 = List(new Array[Byte](200), new Array[Byte](200)) + // First store list1 and list2, both in memory, and list3, on disk only + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + // At this point LRU should not kick in because list3 is only on disk + assert(store.get("list1") != None, "list2 was not in store") + assert(store.get("list1").get.size === 2) + assert(store.get("list2") != None, "list3 was not in store") + assert(store.get("list2").get.size === 2) + assert(store.get("list3") != None, "list1 was not in store") + assert(store.get("list3").get.size === 2) + assert(store.get("list1") != None, "list2 was not in store") + assert(store.get("list1").get.size === 2) + assert(store.get("list2") != None, "list3 was not in store") + assert(store.get("list2").get.size === 2) + assert(store.get("list3") != None, "list1 was not in store") + assert(store.get("list3").get.size === 2) + // Now let's add in list4, which uses both disk and memory; list1 should drop out + store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) + assert(store.get("list1") === None, "list1 was in store") + assert(store.get("list2") != None, "list3 was not in store") + assert(store.get("list2").get.size === 2) + assert(store.get("list3") != None, "list1 was not in store") + assert(store.get("list3").get.size === 2) + assert(store.get("list4") != None, "list4 was not in store") + assert(store.get("list4").get.size === 2) + } + + test("negative byte values in ByteBufferInputStream") { + val buffer = ByteBuffer.wrap(Array[Int](254, 255, 0, 1, 2).map(_.toByte).toArray) + val stream = new ByteBufferInputStream(buffer) + val temp = new Array[Byte](10) + assert(stream.read() === 254, "unexpected byte read") + assert(stream.read() === 255, "unexpected byte read") + assert(stream.read() === 0, "unexpected byte read") + assert(stream.read(temp, 0, temp.length) === 2, "unexpected number of bytes read") + assert(stream.read() === -1, "end of stream not signalled") + assert(stream.read(temp, 0, temp.length) === -1, "end of stream not signalled") + } + + test("overly large block") { + store = new BlockManager("", actorSystem, master, serializer, 500) + store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a1") === None, "a1 was in store") + store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) + assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") + assert(store.getSingle("a2") != None, "a2 was not in store") + } + + test("block compression") { + try { + System.setProperty("spark.shuffle.compress", "true") + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.shuffle.compress", "false") + store = new BlockManager("exec2", actorSystem, master, serializer, 2000) + store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") + store.stop() + store = null + + System.setProperty("spark.broadcast.compress", "true") + store = new BlockManager("exec3", actorSystem, master, serializer, 2000) + store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.broadcast.compress", "false") + store = new BlockManager("exec4", actorSystem, master, serializer, 2000) + store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") + store.stop() + store = null + + System.setProperty("spark.rdd.compress", "true") + store = new BlockManager("exec5", actorSystem, master, serializer, 2000) + store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.rdd.compress", "false") + store = new BlockManager("exec6", actorSystem, master, serializer, 2000) + store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") + store.stop() + store = null + + // Check that any other block types are also kept uncompressed + store = new BlockManager("exec7", actorSystem, master, serializer, 2000) + store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) + assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") + store.stop() + store = null + } finally { + System.clearProperty("spark.shuffle.compress") + System.clearProperty("spark.broadcast.compress") + System.clearProperty("spark.rdd.compress") + } + } + + test("block store put failure") { + // Use Java serializer so we can create an unserializable error. + store = new BlockManager("", actorSystem, master, new JavaSerializer, 1200) + + // The put should fail since a1 is not serializable. + class UnserializableClass + val a1 = new UnserializableClass + intercept[java.io.NotSerializableException] { + store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + } + + // Make sure get a1 doesn't hang and returns None. + failAfter(1 second) { + assert(store.getSingle("a1") == None, "a1 should not be in store") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala new file mode 100644 index 0000000000..3321fb5eb7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.util.{Failure, Success, Try} +import java.net.ServerSocket +import org.scalatest.FunSuite +import org.eclipse.jetty.server.Server + +class UISuite extends FunSuite { + test("jetty port increases under contention") { + val startPort = 3030 + val server = new Server(startPort) + server.start() + val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq()) + val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq()) + + // Allow some wiggle room in case ports on the machine are under contention + assert(boundPort1 > startPort && boundPort1 < startPort + 10) + assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) + } + + test("jetty binds to port 0 correctly") { + val (jettyServer, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq()) + assert(jettyServer.getState === "STARTED") + assert(boundPort != 0) + Try {new ServerSocket(boundPort)} match { + case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) + case Failure (e) => + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala new file mode 100644 index 0000000000..63642461e4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +/** + * + */ + +class DistributionSuite extends FunSuite with ShouldMatchers { + test("summary") { + val d = new Distribution((1 to 100).toArray.map{_.toDouble}) + val stats = d.statCounter + stats.count should be (100) + stats.mean should be (50.5) + stats.sum should be (50 * 101) + + val quantiles = d.getQuantiles() + quantiles(0) should be (1) + quantiles(1) should be (26) + quantiles(2) should be (51) + quantiles(3) should be (76) + quantiles(4) should be (100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/core/src/test/scala/org/apache/spark/util/FakeClock.scala new file mode 100644 index 0000000000..0a45917b08 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/FakeClock.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +class FakeClock extends Clock { + private var time = 0L + + def advance(millis: Long): Unit = time += millis + + def getTime(): Long = time +} diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala new file mode 100644 index 0000000000..45867463a5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable.Buffer +import java.util.NoSuchElementException + +class NextIteratorSuite extends FunSuite with ShouldMatchers { + test("one iteration") { + val i = new StubIterator(Buffer(1)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("two iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === true + i.next should be === 2 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("empty iteration") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("close is called once for empty iterations") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + i.hasNext should be === false + i.closeCalled should be === 1 + } + + test("close is called once for non-empty iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.next should be === 1 + i.next should be === 2 + // close isn't called until we check for the next element + i.closeCalled should be === 0 + i.hasNext should be === false + i.closeCalled should be === 1 + i.hasNext should be === false + i.closeCalled should be === 1 + } + + class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { + var closeCalled = 0 + + override def getNext() = { + if (ints.size == 0) { + finished = true + 0 + } else { + ints.remove(0) + } + } + + override def close() { + closeCalled += 1 + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala new file mode 100644 index 0000000000..a9dd0b1a5b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import java.io.ByteArrayOutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStreamSuite extends FunSuite { + + private def benchmark[U](f: => U): Long = { + val start = System.nanoTime + f + System.nanoTime - start + } + + test("write") { + val underlying = new ByteArrayOutputStream + val data = "X" * 41000 + val stream = new RateLimitedOutputStream(underlying, 10000) + val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + assert(underlying.toString("UTF-8") == data) + } +} diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala deleted file mode 100644 index 0af175f316..0000000000 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import collection.mutable -import java.util.Random -import scala.math.exp -import scala.math.signum -import spark.SparkContext._ - -class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - - test ("basic accumulation"){ - sc = new SparkContext("local", "test") - val acc : Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - d.foreach{x => acc += x} - acc.value should be (210) - - - val longAcc = sc.accumulator(0l) - val maxInt = Integer.MAX_VALUE.toLong - d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210l + maxInt * 20) - } - - test ("value not assignable from tasks") { - sc = new SparkContext("local", "test") - val acc : Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - evaluating {d.foreach{x => acc.value = x}} should produce [Exception] - } - - test ("add value to collection accumulators") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - d.foreach { - x => acc += x - } - val v = acc.value.asInstanceOf[mutable.Set[Int]] - for (i <- 1 to maxI) { - v should contain(i) - } - resetSparkContext() - } - } - - implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { - def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { - new mutable.HashSet[Any]() - } - } - - test ("value not readable in tasks") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - evaluating { - d.foreach { - x => acc.value += x - } - } should produce [SparkException] - resetSparkContext() - } - } - - test ("collection accumulators") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { - // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) - val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) - val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) - val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) - d.foreach { - x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} - } - - // Note that this is typed correctly -- no casts necessary - setAcc.value.size should be (maxI) - bufferAcc.value.size should be (2 * maxI) - mapAcc.value.size should be (maxI) - for (i <- 1 to maxI) { - setAcc.value should contain(i) - bufferAcc.value should contain(i) - mapAcc.value should contain (i -> i.toString) - } - resetSparkContext() - } - } - - test ("localValue readable in tasks") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet} - val d = sc.parallelize(groupedInts) - d.foreach { - x => acc.localValue ++= x - } - acc.value should be ( (0 to maxI).toSet) - resetSparkContext() - } - } - -} diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala deleted file mode 100644 index 785721ece8..0000000000 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite - -class BroadcastSuite extends FunSuite with LocalSparkContext { - - test("basic broadcast") { - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) - } - - test("broadcast variables accessed in multiple threads") { - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) - } -} diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala deleted file mode 100644 index 966dede2be..0000000000 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ /dev/null @@ -1,392 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import java.io.File -import spark.rdd._ -import spark.SparkContext._ -import storage.StorageLevel - -class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { - initLogging() - - var checkpointDir: File = _ - val partitioner = new HashPartitioner(2) - - override def beforeEach() { - super.beforeEach() - checkpointDir = File.createTempFile("temp", "") - checkpointDir.delete() - sc = new SparkContext("local", "test") - sc.setCheckpointDir(checkpointDir.toString) - } - - override def afterEach() { - super.afterEach() - if (checkpointDir != null) { - checkpointDir.delete() - } - } - - test("basic checkpointing") { - val parCollection = sc.makeRDD(1 to 4) - val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) - val result = flatMappedRDD.collect() - assert(flatMappedRDD.dependencies.head.rdd != parCollection) - assert(flatMappedRDD.collect() === result) - } - - test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithIndexRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) - } - - test("ParallelCollection") { - val parCollection = sc.makeRDD(1 to 4, 2) - val numPartitions = parCollection.partitions.size - parCollection.checkpoint() - assert(parCollection.dependencies === Nil) - val result = parCollection.collect() - assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) - assert(parCollection.dependencies != Nil) - assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) - assert(parCollection.collect() === result) - } - - test("BlockRDD") { - val blockId = "id" - val blockManager = SparkEnv.get.blockManager - blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) - val blockRDD = new BlockRDD[String](sc, Array(blockId)) - val numPartitions = blockRDD.partitions.size - blockRDD.checkpoint() - val result = blockRDD.collect() - assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) - assert(blockRDD.dependencies != Nil) - assert(blockRDD.partitions.length === numPartitions) - assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) - assert(blockRDD.collect() === result) - } - - test("ShuffledRDD") { - testCheckpointing(rdd => { - // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) - }) - } - - test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) - - // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. - // Current implementation of UnionRDD has transient reference to parent RDDs, - // so only the partitions will reduce in serialized size, not the RDD. - testCheckpointing(_.union(otherRDD), false, true) - testParentCheckpointing(_.union(otherRDD), false, true) - } - - test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) - testCheckpointing(new CartesianRDD(sc, _, otherRDD)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) - - // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. - // Note that this test is very specific to the current implementation of CartesianRDD. - val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD - val cartesian = new CartesianRDD(sc, ones, ones) - val splitBeforeCheckpoint = - serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) - cartesian.count() // do the checkpointing - val splitAfterCheckpoint = - serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) - assert( - (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && - (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), - "CartesianRDD.parents not updated after parent RDD checkpointed" - ) - } - - test("CoalescedRDD") { - testCheckpointing(_.coalesce(2)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(_.coalesce(2), true, false) - - // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. - // Note that this test is very specific to the current implementation of CoalescedRDDPartitions - val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD - val coalesced = new CoalescedRDD(ones, 2) - val splitBeforeCheckpoint = - serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) - coalesced.count() // do the checkpointing - val splitAfterCheckpoint = - serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) - assert( - splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, - "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" - ) - } - - test("CoGroupedRDD") { - val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() - testCheckpointing(rdd => { - CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }, false, true) - - val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() - testParentCheckpointing(rdd => { - CheckpointSuite.cogroup( - longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }, false, true) - } - - test("ZippedRDD") { - testCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - - // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed - // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - } - - test("CheckpointRDD with zero partitions") { - val rdd = new BlockRDD[Int](sc, Array[String]()) - assert(rdd.partitions.size === 0) - assert(rdd.isCheckpointed === false) - rdd.checkpoint() - assert(rdd.count() === 0) - assert(rdd.isCheckpointed === true) - assert(rdd.partitions.size === 0) - } - - /** - * Test checkpointing of the final RDD generated by the given operation. By default, - * this method tests whether the size of serialized RDD has reduced after checkpointing or not. - * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or - * not, but this is not done by default as usually the partitions do not refer to any RDD and - * therefore never store the lineage. - */ - def testCheckpointing[U: ClassManifest]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean = true, - testRDDPartitionSize: Boolean = false - ) { - // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - operatedRDD.checkpoint() - val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - - // Test whether the checkpoint file has been created - assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) - - // Test whether serialized size of the RDD has reduced. If the RDD - // does not have any dependency to another RDD (e.g., ParallelCollection, - // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. - if (testRDDSize) { - logInfo("Size of " + rddType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - // Test whether serialized size of the partitions has reduced. If the partitions - // do not have any non-transient reference to another RDD or another RDD's partitions, it - // does not refer to a lineage and therefore may not reduce in size after checkpointing. - // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions - // must be forgotten after checkpointing (to remove all reference to parent RDDs) and - // replaced with the HadooPartitions of the checkpointed RDD. - if (testRDDPartitionSize) { - logInfo("Size of " + rddType + " partitions " - + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing " + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * this RDD will remember the partitions and therefore potentially the whole lineage. - */ - def testParentCheckpointing[U: ClassManifest]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean, - testRDDPartitionSize: Boolean - ) { - // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.head.rdd - val rddType = operatedRDD.getClass.getSimpleName - val parentRDDType = parentRDD.getClass.getSimpleName - - // Get the partitions and dependencies of the parent in case they're lazily computed - parentRDD.dependencies - parentRDD.partitions - - // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one - val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - - // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) - - // Test whether serialized size of the RDD has reduced because of its parent being - // checkpointed. If this RDD or its parent RDD do not have any dependency - // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may - // not reduce in size after checkpointing. - if (testRDDSize) { - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - // Test whether serialized size of the partitions has reduced because of its parent being - // checkpointed. If the partitions do not have any non-transient reference to another RDD - // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce - // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent - // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's - // partitions must have changed after checkpointing. - if (testRDDPartitionSize) { - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } - - } - - /** - * Generate an RDD with a long lineage of one-to-one dependencies. - */ - def generateLongLineageRDD(): RDD[Int] = { - var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 50) { - rdd = rdd.map(x => x + 1) - } - rdd - } - - /** - * Generate an RDD with a long lineage specifically for CoGroupedRDD. - * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage - * and narrow dependency with this RDD. This method generate such an RDD by a sequence - * of cogroups and mapValues which creates a long lineage of narrow dependencies. - */ - def generateLongLineageRDDForCoGroupedRDD() = { - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) - - def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) - - var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) - for(i <- 1 to 10) { - cogrouped = cogrouped.mapValues(add).cogroup(ones) - } - cogrouped.mapValues(add) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. - */ - def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, - Utils.serialize(rdd.partitions).length) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } -} - - -object CheckpointSuite { - // This is a custom cogroup function that does not use mapValues like - // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - //println("First = " + first + ", second = " + second) - new CoGroupedRDD[K]( - Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), - part - ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] - } - -} diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala deleted file mode 100644 index 7d2831e19c..0000000000 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.NotSerializableException - -import org.scalatest.FunSuite -import spark.LocalSparkContext._ -import SparkContext._ - -class ClosureCleanerSuite extends FunSuite { - test("closures inside an object") { - assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class") { - val obj = new TestClass - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class with no default constructor") { - val obj = new TestClassWithoutDefaultConstructor(5) - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures that don't use fields of the outer class") { - val obj = new TestClassWithoutFieldAccess - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("nested closures inside an object") { - assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } - - test("nested closures inside a class") { - val obj = new TestClassWithNesting(1) - assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } -} - -// A non-serializable class we create in closures to make sure that we aren't -// keeping references to unneeded variables from our outer closures. -class NonSerializable {} - -object TestObject { - def run(): Int = { - var nonSer = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - -class TestClass extends Serializable { - var x = 5 - - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -// This class is not serializable, but we aren't using any of its fields in our -// closures, so they won't have a $outer pointing to it and should still work. -class TestClassWithoutFieldAccess { - var nonSer = new NonSerializable - - def run(): Int = { - var nonSer2 = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - - -object TestObjectWithNesting { - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) - } - answer - } - } -} - -class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y - - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) - } - answer - } - } -} diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala deleted file mode 100644 index e11efe459c..0000000000 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ /dev/null @@ -1,362 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import network.ConnectionManagerId -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers -import org.scalatest.prop.Checkers -import org.scalatest.time.{Span, Millis} -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ -import org.eclipse.jetty.server.{Server, Request, Handler} - -import com.google.common.io.Files - -import scala.collection.mutable.ArrayBuffer - -import SparkContext._ -import storage.{GetBlock, BlockManagerWorker, StorageLevel} -import ui.JettyUtils - - -class NotSerializableClass -class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} - - -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter - with LocalSparkContext { - - val clusterUrl = "local-cluster[2,1,512]" - - after { - System.clearProperty("spark.reducer.maxMbInFlight") - System.clearProperty("spark.storage.memoryFraction") - } - - test("task throws not serializable exception") { - // Ensures that executors do not crash when an exn is not serializable. If executors crash, - // this test will hang. Correct behavior is that executors don't crash but fail tasks - // and the scheduler throws a SparkException. - - // numSlaves must be less than numPartitions - val numSlaves = 3 - val numPartitions = 10 - - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") - val data = sc.parallelize(1 to 100, numPartitions). - map(x => throw new NotSerializableExn(new NotSerializableClass)) - intercept[SparkException] { - data.count() - } - resetSparkContext() - } - - test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - } - - test("simple groupByKey") { - sc = new SparkContext(clusterUrl, "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5) - val groups = pairs.groupByKey(5).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey where map output sizes exceed maxMbInFlight") { - System.setProperty("spark.reducer.maxMbInFlight", "1") - sc = new SparkContext(clusterUrl, "test") - // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output - // file should be about 2.5 MB - val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) - val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() - assert(groups.length === 16) - assert(groups.map(_._2).sum === 2000) - // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block - } - - test("accumulators") { - sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) - sc.parallelize(1 to 10, 10).foreach(x => accum += x) - assert(accum.value === 55) - } - - test("broadcast variables") { - sc = new SparkContext(clusterUrl, "test") - val array = new Array[Int](100) - val bv = sc.broadcast(array) - array(2) = 3 // Change the array -- this should not be seen on workers - val rdd = sc.parallelize(1 to 10, 10) - val sum = rdd.map(x => bv.value.sum).reduce(_ + _) - assert(sum === 0) - } - - test("repeatedly failing task") { - sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) - val thrown = intercept[SparkException] { - sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("more than 4 times")) - } - - test("caching") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).cache() - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) - - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - - // Get all the locations of the first partition and try to fetch the partitions - // from those locations. - val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray - val blockId = blockIds(0) - val blockManager = SparkEnv.get.blockManager - blockManager.master.getLocations(blockId).foreach(id => { - val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.host, id.port)) - val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList - assert(deserialized === (1 to 100).toList) - }) - } - - test("compute without caching when no partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.0001") - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory - val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") - } - - test("compute when only some partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.01") - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions - // to make sure that *some* of them do fit though - val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") - } - - test("passing environment variables to cluster") { - sc = new SparkContext(clusterUrl, "test", null, Nil, Map("TEST_VAR" -> "TEST_VALUE")) - val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() - assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) - } - - test("recover from node failures") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(Seq(true, true), 2) - assert(data.count === 2) // force executors to start - assert(data.map(markNodeIfIdentity).collect.size === 2) - assert(data.map(failOnMarkedIdentity).collect.size === 2) - } - - test("recover from repeated node failures during shuffle-map") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, false), 2) - assert(data.count === 2) - assert(data.map(markNodeIfIdentity).collect.size === 2) - assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) - } - } - - test("recover from repeated node failures during shuffle-reduce") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, true), 2) - assert(data.count === 2) - assert(data.map(markNodeIfIdentity).collect.size === 2) - // This relies on mergeCombiners being used to perform the actual reduce for this - // test to actually be testing what it claims. - val grouped = data.map(x => x -> x).combineByKey( - x => x, - (x: Boolean, y: Boolean) => x, - (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) - ) - assert(grouped.collect.size === 1) - } - } - - test("recover from node failures with replication") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - // Using more than two nodes so we don't have a symmetric communication pattern and might - // cache a partially correct list of peers. - sc = new SparkContext("local-cluster[3,1,512]", "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, false, false, false), 4) - data.persist(StorageLevel.MEMORY_ONLY_2) - - assert(data.count === 4) - assert(data.map(markNodeIfIdentity).collect.size === 4) - assert(data.map(failOnMarkedIdentity).collect.size === 4) - - // Create a new replicated RDD to make sure that cached peer information doesn't cause - // problems. - val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2) - assert(data2.count === 2) - } - } - - test("unpersist RDDs") { - DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") - val data = sc.parallelize(Seq(true, false, false, false), 4) - data.persist(StorageLevel.MEMORY_ONLY_2) - data.count - assert(sc.persistentRdds.isEmpty === false) - data.unpersist() - assert(sc.persistentRdds.isEmpty === true) - - failAfter(Span(3000, Millis)) { - try { - while (! sc.getRDDStorageInfo.isEmpty) { - Thread.sleep(200) - } - } catch { - case _ => { Thread.sleep(10) } - // Do nothing. We might see exceptions because block manager - // is racing this thread to remove entries from the driver. - } - } - } - - test("job should fail if TaskResult exceeds Akka frame size") { - // We must use local-cluster mode since results are returned differently - // when running under LocalScheduler: - sc = new SparkContext("local-cluster[1,1,512]", "test") - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt - val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} - val exception = intercept[SparkException] { - rdd.reduce((x, y) => x) - } - exception.getMessage should endWith("result exceeded Akka frame size") - } -} - -object DistributedSuite { - // Indicates whether this JVM is marked for failure. - var mark = false - - // Set by test to remember if we are in the driver program so we can assert - // that we are not. - var amMaster = false - - // Act like an identity function, but if the argument is true, set mark to true. - def markNodeIfIdentity(item: Boolean): Boolean = { - if (item) { - assert(!amMaster) - mark = true - } - item - } - - // Act like an identity function, but if mark was set to true previously, fail, - // crashing the entire JVM. - def failOnMarkedIdentity(item: Boolean): Boolean = { - if (mark) { - System.exit(42) - } - item - } -} diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala deleted file mode 100644 index 553c0309f6..0000000000 --- a/core/src/test/scala/spark/DriverSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.File - -import org.apache.log4j.Logger -import org.apache.log4j.Level - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts -import org.scalatest.prop.TableDrivenPropertyChecks._ -import org.scalatest.time.SpanSugar._ - -class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { - assert(System.getenv("SPARK_HOME") != null) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) - forAll(masters) { (master: String) => - failAfter(30 seconds) { - Utils.execute(Seq("./spark-class", "spark.DriverWithoutCleanup", master), - new File(System.getenv("SPARK_HOME"))) - } - } - } -} - -/** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. - */ -object DriverWithoutCleanup { - def main(args: Array[String]) { - Logger.getRootLogger().setLevel(Level.WARN) - val sc = new SparkContext(args(0), "DriverWithoutCleanup") - sc.parallelize(1 to 100, 4).count() - } -} diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala deleted file mode 100644 index 5b133cdd6e..0000000000 --- a/core/src/test/scala/spark/FailureSuite.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite - -import SparkContext._ - -// Common state shared by FailureSuite-launched tasks. We use a global object -// for this because any local variables used in the task closures will rightfully -// be copied for each task, so there's no other way for them to share state. -object FailureSuiteState { - var tasksRun = 0 - var tasksFailed = 0 - - def clear() { - synchronized { - tasksRun = 0 - tasksFailed = 0 - } - } -} - -class FailureSuite extends FunSuite with LocalSparkContext { - - // Run a 3-task map job in which task 1 deterministically fails once, and check - // whether the job completes successfully and we ran 4 tasks in total. - test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3, 3).map { x => - FailureSuiteState.synchronized { - FailureSuiteState.tasksRun += 1 - if (x == 1 && FailureSuiteState.tasksFailed == 0) { - FailureSuiteState.tasksFailed += 1 - throw new Exception("Intentional task failure") - } - } - x * x - }.collect() - FailureSuiteState.synchronized { - assert(FailureSuiteState.tasksRun === 4) - } - assert(results.toList === List(1,4,9)) - FailureSuiteState.clear() - } - - // Run a map-reduce job in which a reduce task deterministically fails once. - test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { - case (k, v) => - FailureSuiteState.synchronized { - FailureSuiteState.tasksRun += 1 - if (k == 1 && FailureSuiteState.tasksFailed == 0) { - FailureSuiteState.tasksFailed += 1 - throw new Exception("Intentional task failure") - } - } - (k, v(0) * v(0)) - }.collect() - FailureSuiteState.synchronized { - assert(FailureSuiteState.tasksRun === 4) - } - assert(results.toSet === Set((1, 1), (2, 4), (3, 9))) - FailureSuiteState.clear() - } - - test("failure because task results are not serializable") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) - - val thrown = intercept[SparkException] { - results.collect() - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) - - FailureSuiteState.clear() - } - - test("failure because task closure is not serializable") { - sc = new SparkContext("local[1,1]", "test") - val a = new NonSerializable - - // Non-serializable closure in the final result stage - val thrown = intercept[SparkException] { - sc.parallelize(1 to 10, 2).map(x => a).count() - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) - - // Non-serializable closure in an earlier stage - val thrown1 = intercept[SparkException] { - sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() - } - assert(thrown1.getClass === classOf[SparkException]) - assert(thrown1.getMessage.contains("NotSerializableException")) - - // Non-serializable closure in foreach function - val thrown2 = intercept[SparkException] { - sc.parallelize(1 to 10, 2).foreach(x => println(a)) - } - assert(thrown2.getClass === classOf[SparkException]) - assert(thrown2.getMessage.contains("NotSerializableException")) - - FailureSuiteState.clear() - } - - // TODO: Need to add tests with shuffle fetch failures. -} - - diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala deleted file mode 100644 index 242ae971f8..0000000000 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import com.google.common.io.Files -import org.scalatest.FunSuite -import java.io.{File, PrintWriter, FileReader, BufferedReader} -import SparkContext._ - -class FileServerSuite extends FunSuite with LocalSparkContext { - - @transient var tmpFile: File = _ - @transient var testJarFile: File = _ - - override def beforeEach() { - super.beforeEach() - // Create a sample text file - val tmpdir = new File(Files.createTempDir(), "test") - tmpdir.mkdir() - tmpFile = new File(tmpdir, "FileServerSuite.txt") - val pw = new PrintWriter(tmpFile) - pw.println("100") - pw.close() - } - - override def afterEach() { - super.afterEach() - // Clean up downloaded file - if (tmpFile.exists) { - tmpFile.delete() - } - } - - test("Distributing files locally") { - sc = new SparkContext("local[4]", "test") - sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test("Distributing files locally using URL as input") { - // addFile("file:///....") - sc = new SparkContext("local[4]", "test") - sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) - } - - test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") - sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) - } -} diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala deleted file mode 100644 index 1e2c257c4b..0000000000 --- a/core/src/test/scala/spark/FileSuite.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.{FileWriter, PrintWriter, File} - -import scala.io.Source - -import com.google.common.io.Files -import org.scalatest.FunSuite -import org.apache.hadoop.io._ -import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} - - -import SparkContext._ - -class FileSuite extends FunSuite with LocalSparkContext { - - test("text files") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 4) - nums.saveAsTextFile(outputDir) - // Read the plain text file and check it's OK - val outputFile = new File(outputDir, "part-00000") - val content = Source.fromFile(outputFile).mkString - assert(content === "1\n2\n3\n4\n") - // Also try reading it in as a text file RDD - assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) - } - - test("text files (compressed)") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val normalDir = new File(tempDir, "output_normal").getAbsolutePath - val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath - val codec = new DefaultCodec() - - val data = sc.parallelize("a" * 10000, 1) - data.saveAsTextFile(normalDir) - data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) - - val normalFile = new File(normalDir, "part-00000") - val normalContent = sc.textFile(normalDir).collect - assert(normalContent === Array.fill(10000)("a")) - - val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) - val compressedContent = sc.textFile(compressedOutputDir).collect - assert(compressedContent === Array.fill(10000)("a")) - - assert(compressedFile.length < normalFile.length) - } - - test("SequenceFiles") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile (compressed)") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val normalDir = new File(tempDir, "output_normal").getAbsolutePath - val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath - val codec = new DefaultCodec() - - val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) - data.saveAsSequenceFile(normalDir) - data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) - - val normalFile = new File(normalDir, "part-00000") - val normalContent = sc.sequenceFile[String, String](normalDir).collect - assert(normalContent === Array.fill(100)("abc", "abc")) - - val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) - val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect - assert(compressedContent === Array.fill(100)("abc", "abc")) - - assert(compressedFile.length < normalFile.length) - } - - test("SequenceFile with writable key") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile with writable value") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile with writable key and value") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("implicit conversions in reading SequenceFiles") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) - nums.saveAsSequenceFile(outputDir) - // Similar to the tests above, we read a SequenceFile, but this time we pass type params - // that are convertable to Writable instead of calling sequenceFile[IntWritable, Text] - val output1 = sc.sequenceFile[Int, String](outputDir) - assert(output1.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) - // Also try having one type be a subclass of Writable and one not - val output2 = sc.sequenceFile[Int, Text](outputDir) - assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - val output3 = sc.sequenceFile[IntWritable, String](outputDir) - assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("object files of ints") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 4) - nums.saveAsObjectFile(outputDir) - // Try reading the output back as an object file - val output = sc.objectFile[Int](outputDir) - assert(output.collect().toList === List(1, 2, 3, 4)) - } - - test("object files of complex types") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) - nums.saveAsObjectFile(outputDir) - // Try reading the output back as an object file - val output = sc.objectFile[(Int, String)](outputDir) - assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) - } - - test("write SequenceFile using new Hadoop API") { - import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]]( - outputDir) - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("read SequenceFile using new Hadoop API") { - import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("file caching") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val out = new FileWriter(tempDir + "/input") - out.write("Hello world!\n") - out.write("What's up?\n") - out.write("Goodbye\n") - out.close() - val rdd = sc.textFile(tempDir + "/input").cache() - assert(rdd.count() === 3) - assert(rdd.count() === 3) - assert(rdd.count() === 3) - } -} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java deleted file mode 100644 index c337c49268..0000000000 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ /dev/null @@ -1,865 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - -import com.google.common.base.Optional; -import scala.Tuple2; - -import com.google.common.base.Charsets; -import org.apache.hadoop.io.compress.DefaultCodec; -import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapred.SequenceFileInputFormat; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapreduce.Job; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaDoubleRDD; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.*; -import spark.partial.BoundedDouble; -import spark.partial.PartialResult; -import spark.storage.StorageLevel; -import spark.util.StatCounter; - - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port"); - } - - static class ReverseIntComparator implements Comparator, Serializable { - - @Override - public int compare(Integer a, Integer b) { - if (a > b) return -1; - else if (a < b) return 1; - else return 0; - } - }; - - @Test - public void sparkContextUnion() { - // Union of non-specialized JavaRDDs - List strings = Arrays.asList("Hello", "World"); - JavaRDD s1 = sc.parallelize(strings); - JavaRDD s2 = sc.parallelize(strings); - // Varargs - JavaRDD sUnion = sc.union(s1, s2); - Assert.assertEquals(4, sUnion.count()); - // List - List> list = new ArrayList>(); - list.add(s2); - sUnion = sc.union(s1, list); - Assert.assertEquals(4, sUnion.count()); - - // Union of JavaDoubleRDDs - List doubles = Arrays.asList(1.0, 2.0); - JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); - JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); - JavaDoubleRDD dUnion = sc.union(d1, d2); - Assert.assertEquals(4, dUnion.count()); - - // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); - JavaPairRDD p1 = sc.parallelizePairs(pairs); - JavaPairRDD p2 = sc.parallelizePairs(pairs); - JavaPairRDD pUnion = sc.union(p1, p2); - Assert.assertEquals(4, pUnion.count()); - } - - @Test - public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); - - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - // Default comparator - JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); - List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); - - // Custom comparator - sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); - sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); - } - - static int foreachCalls = 0; - - @Test - public void foreach() { - foreachCalls = 0; - JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - foreachCalls++; - } - }); - Assert.assertEquals(2, foreachCalls); - } - - @Test - public void lookup() { - JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); - Assert.assertEquals(2, categories.lookup("Oranges").size()); - Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); - } - - @Test - public void groupBy() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function isOdd = new Function() { - @Override - public Boolean call(Integer x) { - return x % 2 == 0; - } - }; - JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds - - oddsAndEvens = rdd.groupBy(isOdd, 1); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds - } - - @Test - public void cogroup() { - JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); - JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) - )); - JavaPairRDD, List>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString()); - Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString()); - - cogrouped.collect(); - } - - @Test - public void leftOuterJoin() { - JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) - )); - JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') - )); - List>>> joined = - rdd1.leftOuterJoin(rdd2).collect(); - Assert.assertEquals(5, joined.size()); - Tuple2>> firstUnmatched = - rdd1.leftOuterJoin(rdd2).filter( - new Function>>, Boolean>() { - @Override - public Boolean call(Tuple2>> tup) - throws Exception { - return !tup._2()._2().isPresent(); - } - }).first(); - Assert.assertEquals(3, firstUnmatched._1().intValue()); - } - - @Test - public void foldReduce() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; - - int sum = rdd.fold(0, add); - Assert.assertEquals(33, sum); - - sum = rdd.reduce(add); - Assert.assertEquals(33, sum); - } - - @Test - public void foldByKey() { - List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD sums = rdd.foldByKey(0, - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); - Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); - Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); - } - - @Test - public void reduceByKey() { - List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD counts = rdd.reduceByKey( - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); - Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); - Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); - - Map localCounts = counts.collectAsMap(); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - - localCounts = rdd.reduceByKeyLocally(new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - } - - @Test - public void approximateResults() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Map countsByValue = rdd.countByValue(); - Assert.assertEquals(2, countsByValue.get(1).longValue()); - Assert.assertEquals(1, countsByValue.get(13).longValue()); - - PartialResult> approx = rdd.countByValueApprox(1); - Map finalValue = approx.getFinalValue(); - Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01); - Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01); - } - - @Test - public void take() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Assert.assertEquals(1, rdd.first().intValue()); - List firstTwo = rdd.take(2); - List sample = rdd.takeSample(false, 2, 42); - } - - @Test - public void cartesian() { - JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); - JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); - } - - @Test - public void javaDoubleRDD() { - JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - JavaDoubleRDD distinct = rdd.distinct(); - Assert.assertEquals(5, distinct.count()); - JavaDoubleRDD filter = rdd.filter(new Function() { - @Override - public Boolean call(Double x) { - return x > 2.0; - } - }); - Assert.assertEquals(3, filter.count()); - JavaDoubleRDD union = rdd.union(rdd); - Assert.assertEquals(12, union.count()); - union = union.cache(); - Assert.assertEquals(12, union.count()); - - Assert.assertEquals(20, rdd.sum(), 0.01); - StatCounter stats = rdd.stats(); - Assert.assertEquals(20, stats.sum(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(6.22222, rdd.variance(), 0.01); - Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01); - Assert.assertEquals(2.49444, rdd.stdev(), 0.01); - Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01); - - Double first = rdd.first(); - List take = rdd.take(5); - } - - @Test - public void map() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { - @Override - public Double call(Integer x) { - return 1.0 * x; - } - }).cache(); - JavaPairRDD pairs = rdd.map(new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2(x, x); - } - }).cache(); - JavaRDD strings = rdd.map(new Function() { - @Override - public String call(Integer x) { - return x.toString(); - } - }).cache(); - } - - @Test - public void flatMap() { - JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", - "The quick brown fox jumps over the lazy dog.")); - JavaRDD words = rdd.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Arrays.asList(x.split(" ")); - } - }); - Assert.assertEquals("Hello", words.first()); - Assert.assertEquals(11, words.count()); - - JavaPairRDD pairs = rdd.flatMap( - new PairFlatMapFunction() { - - @Override - public Iterable> call(String s) { - List> pairs = new LinkedList>(); - for (String word : s.split(" ")) pairs.add(new Tuple2(word, word)); - return pairs; - } - } - ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); - - JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction() { - @Override - public Iterable call(String s) { - List lengths = new LinkedList(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); - return lengths; - } - }); - Double x = doubles.first(); - Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); - Assert.assertEquals(11, pairs.count()); - } - - @Test - public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMap( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) throws Exception { - return Collections.singletonList(item.swap()); - } - }); - swapped.collect(); - - // There was never a bug here, but it's worth testing: - pairRDD.map(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) throws Exception { - return item.swap(); - } - }).collect(); - } - - @Test - public void mapPartitions() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitions( - new FlatMapFunction, Integer>() { - @Override - public Iterable call(Iterator iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum); - } - }); - Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); - } - - @Test - public void persist() { - JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(20, doubleRDD.sum(), 0.1); - - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals("a", pairRDD.first()._2()); - - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - rdd = rdd.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(1, rdd.first().intValue()); - } - - @Test - public void iterator() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, null); - Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); - } - - @Test - public void glom() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - Assert.assertEquals("[1, 2]", rdd.glom().first().toString()); - } - - // File input / output tests are largely adapted from FileSuite: - - @Test - public void textFiles() throws IOException { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - rdd.saveAsTextFile(outputDir); - // Read the plain text file and check it's OK - File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, Charsets.UTF_8); - Assert.assertEquals("1\n2\n3\n4\n", content); - // Also try reading it in as a text file RDD - List expected = Arrays.asList("1", "2", "3", "4"); - JavaRDD readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - } - - @Test - public void textFilesCompressed() throws IOException { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - rdd.saveAsTextFile(outputDir, DefaultCodec.class); - - // Try reading it in as a text file RDD - List expected = Arrays.asList("1", "2", "3", "4"); - JavaRDD readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - } - - @Test - public void sequenceFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - // Try reading the output back as an object file - JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, - Text.class).map(new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); - } - }); - Assert.assertEquals(pairs, readRDD.collect()); - } - - @Test - public void writeWithNewAPIHadoopFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void readWithNewAPIHadoopFile() throws IOException { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void objectFilesOfInts() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - rdd.saveAsObjectFile(outputDir); - // Try reading the output back as an object file - List expected = Arrays.asList(1, 2, 3, 4); - JavaRDD readRDD = sc.objectFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - } - - @Test - public void objectFilesOfComplexTypes() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.saveAsObjectFile(outputDir); - // Try reading the output back as an object file - JavaRDD> readRDD = sc.objectFile(outputDir); - Assert.assertEquals(pairs, readRDD.collect()); - } - - @Test - public void hadoopFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void hadoopFileCompressed() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, - DefaultCodec.class); - - JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void zip() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { - @Override - public Double call(Integer x) { - return 1.0 * x; - } - }); - JavaPairRDD zipped = rdd.zip(doubles); - zipped.count(); - } - - @Test - public void zipPartitions() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); - JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); - FlatMapFunction2, Iterator, Integer> sizesFn = - new FlatMapFunction2, Iterator, Integer>() { - @Override - public Iterable call(Iterator i, Iterator s) { - int sizeI = 0; - int sizeS = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS); - } - }; - - JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); - Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); - } - - @Test - public void accumulators() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(new VoidFunction() { - public void call(Integer x) { - intAccum.add(x); - } - }); - Assert.assertEquals((Integer) 25, intAccum.value()); - - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(new VoidFunction() { - public void call(Integer x) { - doubleAccum.add((double) x); - } - }); - Assert.assertEquals((Double) 25.0, doubleAccum.value()); - - // Try a custom accumulator type - AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { - public Float addInPlace(Float r, Float t) { - return r + t; - } - - public Float addAccumulator(Float r, Float t) { - return r + t; - } - - public Float zero(Float initialValue) { - return 0.0f; - } - }; - - final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); - rdd.foreach(new VoidFunction() { - public void call(Integer x) { - floatAccum.add((float) x); - } - }); - Assert.assertEquals((Float) 25.0f, floatAccum.value()); - - // Test the setValue method - floatAccum.setValue(5.0f); - Assert.assertEquals((Float) 5.0f, floatAccum.value()); - } - - @Test - public void keyBy() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(new Function() { - public String call(Integer t) throws Exception { - return t.toString(); - } - }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); - } - - @Test - public void checkpointAndComputation() { - File tempDir = Files.createTempDir(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); - Assert.assertEquals(false, rdd.isCheckpointed()); - rdd.checkpoint(); - rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); - } - - @Test - public void checkpointAndRestore() { - File tempDir = Files.createTempDir(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); - Assert.assertEquals(false, rdd.isCheckpointed()); - rdd.checkpoint(); - rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); - - Assert.assertTrue(rdd.getCheckpointFile().isPresent()); - JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); - } - - @Test - public void mapOnPairRDD() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD rdd2 = rdd1.map(new PairFunction() { - @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i % 2); - } - }); - JavaPairRDD rdd3 = rdd2.map( - new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) throws Exception { - return new Tuple2(in._2(), in._1()); - } - }); - Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); - - } -} diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala deleted file mode 100644 index 7568a0bf65..0000000000 --- a/core/src/test/scala/spark/KryoSerializerSuite.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable - -import org.scalatest.FunSuite -import com.esotericsoftware.kryo._ - -import KryoTest._ - -class KryoSerializerSuite extends FunSuite with SharedSparkContext { - test("basic types") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(1) - check(1L) - check(1.0f) - check(1.0) - check(1.toByte) - check(1.toShort) - check("") - check("hello") - check(Integer.MAX_VALUE) - check(Integer.MIN_VALUE) - check(java.lang.Long.MAX_VALUE) - check(java.lang.Long.MIN_VALUE) - check[String](null) - check(Array(1, 2, 3)) - check(Array(1L, 2L, 3L)) - check(Array(1.0, 2.0, 3.0)) - check(Array(1.0f, 2.9f, 3.9f)) - check(Array("aaa", "bbb", "ccc")) - check(Array("aaa", "bbb", null)) - check(Array(true, false, true)) - check(Array('a', 'b', 'c')) - check(Array[Int]()) - check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) - } - - test("pairs") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check((1, 1)) - check((1, 1L)) - check((1L, 1)) - check((1L, 1L)) - check((1.0, 1)) - check((1, 1.0)) - check((1.0, 1.0)) - check((1.0, 1L)) - check((1L, 1.0)) - check((1.0, 1L)) - check(("x", 1)) - check(("x", 1.0)) - check(("x", 1L)) - check((1, "x")) - check((1.0, "x")) - check((1L, "x")) - check(("x", "x")) - } - - test("Scala data structures") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(List[Int]()) - check(List[Int](1, 2, 3)) - check(List[String]()) - check(List[String]("x", "y", "z")) - check(None) - check(Some(1)) - check(Some("hi")) - check(mutable.ArrayBuffer(1, 2, 3)) - check(mutable.ArrayBuffer("1", "2", "3")) - check(mutable.Map()) - check(mutable.Map(1 -> "one", 2 -> "two")) - check(mutable.Map("one" -> 1, "two" -> 2)) - check(mutable.HashMap(1 -> "one", 2 -> "two")) - check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) - } - - test("custom registrator") { - import KryoTest._ - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - - check(CaseClass(17, "hello")) - - val c1 = new ClassWithNoArgConstructor - c1.x = 32 - check(c1) - - val c2 = new ClassWithoutNoArgConstructor(47) - check(c2) - - val hashMap = new java.util.HashMap[String, String] - hashMap.put("foo", "bar") - check(hashMap) - - System.clearProperty("spark.kryo.registrator") - } - - test("kryo with collect") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) - assert(control === result.toSeq) - } - - test("kryo with parallelize") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() - assert (control === result.toSeq) - } - - test("kryo with parallelize for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) - } - - test("kryo with parallelize for primitive arrays") { - assert (sc.parallelize( Array(1, 2, 3) ).count === 3) - } - - test("kryo with collect for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) - } - - test("kryo with reduce") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(control.sum === result) - } - - // TODO: this still doesn't work - ignore("kryo with fold") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(10 + control.sum === result) - } - - override def beforeAll() { - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - super.beforeAll() - } - - override def afterAll() { - super.afterAll() - System.clearProperty("spark.kryo.registrator") - System.clearProperty("spark.serializer") - } -} - -object KryoTest { - case class CaseClass(i: Int, s: String) {} - - class ClassWithNoArgConstructor { - var x: Int = 0 - override def equals(other: Any) = other match { - case c: ClassWithNoArgConstructor => x == c.x - case _ => false - } - } - - class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { - case c: ClassWithoutNoArgConstructor => x == c.x - case _ => false - } - } - - class MyRegistrator extends KryoRegistrator { - override def registerClasses(k: Kryo) { - k.register(classOf[CaseClass]) - k.register(classOf[ClassWithNoArgConstructor]) - k.register(classOf[ClassWithoutNoArgConstructor]) - k.register(classOf[java.util.HashMap[_, _]]) - } - } -} diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala deleted file mode 100644 index ddc212d290..0000000000 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterEach -import org.scalatest.BeforeAndAfterAll - -import org.jboss.netty.logging.InternalLoggerFactory -import org.jboss.netty.logging.Slf4JLoggerFactory - -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ -trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => - - @transient var sc: SparkContext = _ - - override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); - super.beforeAll() - } - - override def afterEach() { - resetSparkContext() - super.afterEach() - } - - def resetSparkContext() = { - if (sc != null) { - LocalSparkContext.stop(sc) - sc = null - } - } - -} - -object LocalSparkContext { - def stop(sc: SparkContext) { - sc.stop() - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { - try { - f(sc) - } finally { - stop(sc) - } - } - -} diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala deleted file mode 100644 index c21f3331d0..0000000000 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite - -import akka.actor._ -import spark.scheduler.MapStatus -import spark.storage.BlockManagerId -import spark.util.AkkaUtils - -class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { - - test("compressSize") { - assert(MapOutputTracker.compressSize(0L) === 0) - assert(MapOutputTracker.compressSize(1L) === 1) - assert(MapOutputTracker.compressSize(2L) === 8) - assert(MapOutputTracker.compressSize(10L) === 25) - assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) - assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) - // This last size is bigger than we can encode in a byte, so check that we just return 255 - assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) - } - - test("decompressSize") { - assert(MapOutputTracker.decompressSize(0) === 0) - for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { - val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) - assert(size2 >= 0.99 * size && size2 <= 1.11 * size, - "size " + size + " decompressed to " + size2 + ", which is out of range") - } - } - - test("master start and stop") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.stop() - } - - test("master register and fetch") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), - Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), - Array(compressedSize10000, compressedSize1000))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), - (BlockManagerId("b", "hostB", 1000, 0), size10000))) - tracker.stop() - } - - test("master register and unregister and fetch") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) - - // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - - // The remaining reduce task might try to grab the output despite the shuffle failure; - // this should cause it to fail, and the scheduler will ignore the failure due to the - // stage already being aborted. - intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } - } - - test("remote fetch") { - val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) - System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - - val masterTracker = new MapOutputTracker() - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") - - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) - val slaveTracker = new MapOutputTracker() - slaveTracker.trackerActor = slaveSystem.actorFor( - "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) - - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - - // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } -} diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala deleted file mode 100644 index 328b3b5497..0000000000 --- a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet - -import org.scalatest.FunSuite - -import com.google.common.io.Files -import spark.SparkContext._ - - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { - test("groupByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with duplicates") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with negative key hash codes") { - val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesForMinus1 = groups.find(_._1 == -1).get._2 - assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with many output partitions") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey(10).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("reduceByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with collectAsMap") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() - assert(sums.size === 2) - assert(sums(1) === 7) - assert(sums(2) === 1) - } - - test("reduceByKey with many output partitons") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with partitioner") { - val p = new Partitioner() { - def numPartitions = 2 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) - assert(sums.collect().toSet === Set((1, 4), (0, 1))) - assert(sums.partitioner === Some(p)) - // count the dependencies to make sure there is only 1 ShuffledRDD - val deps = new HashSet[RDD[_]]() - def visit(r: RDD[_]) { - for (dep <- r.dependencies) { - deps += dep.rdd - visit(dep.rdd) - } - } - visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection - } - - test("join") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("join all-to-all") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) - val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 6) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (1, 'y')), - (1, (2, 'x')), - (1, (2, 'y')), - (1, (3, 'x')), - (1, (3, 'y')) - )) - } - - test("leftOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.leftOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (1, Some('x'))), - (1, (2, Some('x'))), - (2, (1, Some('y'))), - (2, (1, Some('z'))), - (3, (1, None)) - )) - } - - test("rightOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.rightOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (Some(1), 'x')), - (1, (Some(2), 'x')), - (2, (Some(1), 'y')), - (2, (Some(1), 'z')), - (4, (None, 'w')) - )) - } - - test("join with no matches") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 0) - } - - test("join with many output partitions") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2, 10).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("groupWith") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.groupWith(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) - )) - } - - test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - } - - test("keys and values") { - val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) - assert(rdd.keys.collect().toList === List(1, 2)) - assert(rdd.values.collect().toList === List("a", "b")) - } - - test("default partitioner uses partition size") { - // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) - // do a map, which loses the partitioner - val b = a.map(a => (a, (a * 2).toString)) - // then a group by, and see we didn't revert to 2 partitions - val c = b.groupByKey() - assert(c.partitions.size === 2000) - } - - test("default partitioner uses largest partitioner") { - val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) - val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) - val c = a.join(b) - assert(c.partitions.size === 2000) - } - - test("subtract") { - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set(1)) - assert(c.partitions.size === a.partitions.size) - } - - test("subtract with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set((1, "a"), (3, "c"))) - // Ideally we could keep the original partitioner... - assert(c.partitioner === None) - } - - test("subtractByKey") { - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) - val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitions.size === a.partitions.size) - } - - test("subtractByKey with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitioner.get === p) - } - - test("foldByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("foldByKey with mutable result type") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() - // Fold the values using in-place mutation - val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() - assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) - // Check that the mutable objects in the original RDD were not changed - assert(bufs.collect().toSet === Set( - (1, ArrayBuffer(1)), - (1, ArrayBuffer(2)), - (1, ArrayBuffer(3)), - (1, ArrayBuffer(1)), - (2, ArrayBuffer(1)))) - } -} diff --git a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 88352b639f..0000000000 --- a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,28 +0,0 @@ -package spark - -import org.scalatest.FunSuite -import spark.SparkContext._ -import spark.rdd.PartitionPruningRDD - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) - } - def compute(split: Partition, context: TaskContext) = {Iterator()} - } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) - assert(prunedRDD.partitions.length == 1) - } -} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala deleted file mode 100644 index b1e0b2b4d0..0000000000 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import scala.collection.mutable.ArrayBuffer -import SparkContext._ -import spark.util.StatCounter -import scala.math.abs - -class PartitioningSuite extends FunSuite with SharedSparkContext { - - test("HashPartitioner equality") { - val p2 = new HashPartitioner(2) - val p4 = new HashPartitioner(4) - val anotherP4 = new HashPartitioner(4) - assert(p2 === p2) - assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) - assert(p4 === anotherP4) - assert(anotherP4 === p4) - } - - test("RangePartitioner equality") { - // Make an RDD where all the elements are the same so that the partition range bounds - // are deterministically all the same. - val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) - - val p2 = new RangePartitioner(2, rdd) - val p4 = new RangePartitioner(4, rdd) - val anotherP4 = new RangePartitioner(4, rdd) - val descendingP2 = new RangePartitioner(2, rdd, false) - val descendingP4 = new RangePartitioner(4, rdd, false) - - assert(p2 === p2) - assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) - assert(p4 === anotherP4) - assert(anotherP4 === p4) - assert(descendingP2 === descendingP2) - assert(descendingP4 === descendingP4) - assert(descendingP2 != descendingP4) - assert(descendingP4 != descendingP2) - assert(p2 != descendingP2) - assert(p4 != descendingP4) - assert(descendingP2 != p2) - assert(descendingP4 != p4) - } - - test("HashPartitioner not equal to RangePartitioner") { - val rdd = sc.parallelize(1 to 10).map(x => (x, x)) - val rangeP2 = new RangePartitioner(2, rdd) - val hashP2 = new HashPartitioner(2) - assert(rangeP2 === rangeP2) - assert(hashP2 === hashP2) - assert(hashP2 != rangeP2) - assert(rangeP2 != hashP2) - } - - test("partitioner preservation") { - val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) - - val grouped2 = rdd.groupByKey(2) - val grouped4 = rdd.groupByKey(4) - val reduced2 = rdd.reduceByKey(_ + _, 2) - val reduced4 = rdd.reduceByKey(_ + _, 4) - - assert(rdd.partitioner === None) - - assert(grouped2.partitioner === Some(new HashPartitioner(2))) - assert(grouped4.partitioner === Some(new HashPartitioner(4))) - assert(reduced2.partitioner === Some(new HashPartitioner(2))) - assert(reduced4.partitioner === Some(new HashPartitioner(4))) - - assert(grouped2.groupByKey().partitioner === grouped2.partitioner) - assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner) - assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner) - assert(grouped4.groupByKey().partitioner === grouped4.partitioner) - assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner) - assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner) - - assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) - - assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) - - assert(grouped2.map(_ => 1).partitioner === None) - assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) - assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) - assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) - } - - test("partitioning Java arrays should fail") { - val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) - val arrPairs: RDD[(Array[Int], Int)] = - sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) - // We can't catch all usages of arrays, since they might occur inside other collections: - //assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) - } - - test("zero-length partitions should be correctly handled") { - // Create RDD with some consecutive empty partitions (including the "first" one) - val rdd: RDD[Double] = sc - .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) - .filter(_ >= 0.0) - - // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); - - // Add other tests here for classes that should be able to handle empty partitions correctly - } -} diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala deleted file mode 100644 index 35c04710a3..0000000000 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import SparkContext._ - -class PipedRDDSuite extends FunSuite with SharedSparkContext { - - test("basic pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - - val piped = nums.pipe(Seq("cat")) - - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - } - - test("advanced pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Int, f: String=> Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") - } - - test("pipe with env variable") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") - } - - test("pipe with non-zero exit status") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) - intercept[SparkException] { - piped.collect() - } - } - -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala deleted file mode 100644 index e306952bbd..0000000000 --- a/core/src/test/scala/spark/RDDSuite.scala +++ /dev/null @@ -1,389 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Span, Millis} -import spark.SparkContext._ -import spark.rdd._ -import scala.collection.parallel.mutable - -class RDDSuite extends FunSuite with SharedSparkContext { - - test("basic operations") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(nums.collect().toList === List(1, 2, 3, 4)) - val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count() === 4) - assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? - assert(dups.distinct.collect === dups.distinct().collect) - assert(dups.distinct(2).collect === dups.distinct().collect) - assert(nums.reduce(_ + _) === 10) - assert(nums.fold(0)(_ + _) === 10) - assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) - assert(nums.filter(_ > 2).collect().toList === List(3, 4)) - assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) - assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) - assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) - val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) - assert(partitionSums.collect().toList === List(3, 7)) - - val partitionSumsWithSplit = nums.mapPartitionsWithSplit { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) - - val partitionSumsWithIndex = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) - - intercept[UnsupportedOperationException] { - nums.filter(_ > 5).reduce(_ + _) - } - } - - test("SparkContext.union") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - } - - test("aggregate") { - val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) - type StringMap = HashMap[String, Int] - val emptyMap = new StringMap { - override def default(key: String): Int = 0 - } - val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { - map(pair._1) += pair._2 - map - } - val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { - for ((key, value) <- map2) { - map1(key) += value - } - map1 - } - val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) - assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) - } - - test("basic caching") { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("caching with failures") { - val onlySplit = new Partition { override def index: Int = 0 } - var shouldFail = true - val rdd = new RDD[Int](sc, Nil) { - override def getPartitions: Array[Partition] = Array(onlySplit) - override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - return Array(1, 2, 3, 4).iterator - } - } - }.cache() - val thrown = intercept[Exception]{ - rdd.collect() - } - assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("empty RDD") { - val empty = new EmptyRDD[Int](sc) - assert(empty.count === 0) - assert(empty.collect().size === 0) - - val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) - } - assert(thrown.getMessage.contains("empty")) - - val emptyKv = new EmptyRDD[(Int, Int)](sc) - val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) - assert(rdd.join(emptyKv).collect().size === 0) - assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) - assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) - assert(rdd.cogroup(emptyKv).collect().size === 2) - assert(rdd.union(emptyKv).collect().size === 2) - } - - test("cogrouped RDDs") { - val data = sc.parallelize(1 to 10, 10) - - val coalesced1 = data.coalesce(2) - assert(coalesced1.collect().toList === (1 to 10).toList) - assert(coalesced1.glom().collect().map(_.toList).toList === - List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) - - // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === - List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === - List(5, 6, 7, 8, 9)) - - val coalesced2 = data.coalesce(3) - assert(coalesced2.collect().toList === (1 to 10).toList) - assert(coalesced2.glom().collect().map(_.toList).toList === - List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) - - val coalesced3 = data.coalesce(10) - assert(coalesced3.collect().toList === (1 to 10).toList) - assert(coalesced3.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. - val coalesced4 = data.coalesce(20) - assert(coalesced4.collect().toList === (1 to 10).toList) - assert(coalesced4.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // we can optionally shuffle to keep the upstream parallel - val coalesced5 = data.coalesce(1, shuffle = true) - assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != - null) - } - test("cogrouped RDDs with locality") { - val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) - val coal3 = data3.coalesce(3) - val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) - assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") - - // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) - val coalesced1 = data.coalesce(3) - assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") - - val splits = coalesced1.glom().collect().map(_.toList).toList - assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) - - assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. - val coalesced4 = data.coalesce(20) - val listOfLists = coalesced4.glom().collect().map(_.toList).toList - val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } - assert( sortedList === (1 to 9). - map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") - } - - test("cogrouped RDDs with locality, large scale (10K partitions)") { - // large scale experiment - import collection.mutable - val rnd = scala.util.Random - val partitions = 10000 - val numMachines = 50 - val machines = mutable.ListBuffer[String]() - (1 to numMachines).foreach(machines += "m"+_) - - val blocks = (1 to partitions).map(i => - { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) - - val data2 = sc.makeRDD(blocks) - val coalesced2 = data2.coalesce(numMachines*2) - - // test that you get over 90% locality in each group - val minLocality = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") - - // test that the groups are load balanced with 100 +/- 20 elements in each - val maxImbalance = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) - .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) - assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) - - val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs - val coalesced3 = data3.coalesce(numMachines*2) - val minLocality2 = coalesced3.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.).toInt + "%") - } - - test("zipped RDDs") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val zipped = nums.zip(nums.map(_ + 1.0)) - assert(zipped.glom().map(_.toList).collect().toList === - List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) - - intercept[IllegalArgumentException] { - nums.zip(sc.parallelize(1 to 4, 1)).collect() - } - } - - test("partition pruning") { - val data = sc.parallelize(1 to 10, 10) - // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) - assert(prunedRdd.partitions.size === 1) - val prunedData = prunedRdd.collect() - assert(prunedData.size === 1) - assert(prunedData(0) === 10) - } - - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. - collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. - collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) - else 0 == prng43.nextInt(3)} - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - - test("top with predefined ordering") { - val nums = Array.range(1, 100000) - val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) - val topK = ints.top(5) - assert(topK.size === 5) - assert(topK === nums.reverse.take(5)) - } - - test("top with custom ordering") { - val words = Vector("a", "b", "c", "d") - implicit val ord = implicitly[Ordering[String]].reverse - val rdd = sc.makeRDD(words, 2) - val topK = rdd.top(2) - assert(topK.size === 2) - assert(topK.sorted === Array("b", "a")) - } - - test("takeOrdered with predefined ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val rdd = sc.makeRDD(nums, 2) - val sortedLowerK = rdd.takeOrdered(5) - assert(sortedLowerK.size === 5) - assert(sortedLowerK === Array(1, 2, 3, 4, 5)) - } - - test("takeOrdered with custom ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - implicit val ord = implicitly[Ordering[Int]].reverse - val rdd = sc.makeRDD(nums, 2) - val sortedTopK = rdd.takeOrdered(5) - assert(sortedTopK.size === 5) - assert(sortedTopK === Array(10, 9, 8, 7, 6)) - assert(sortedTopK === nums.sorted(ord).take(5)) - } - - test("takeSample") { - val data = sc.parallelize(1 to 100, 2) - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) - assert(sample.size === 100) // Got only 100 elements - assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements - // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements - // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - } - - test("runJob on an invalid partition") { - intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) - } - } -} diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala deleted file mode 100644 index 70c24515be..0000000000 --- a/core/src/test/scala/spark/SharedSparkContext.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll - -/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ -trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => - - @transient private var _sc: SparkContext = _ - - def sc: SparkContext = _sc - - override def beforeAll() { - _sc = new SparkContext("local", "test") - super.beforeAll() - } - - override def afterAll() { - if (_sc != null) { - LocalSparkContext.stop(_sc) - _sc = null - } - super.afterAll() - } -} diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala deleted file mode 100644 index 6bad6c1d13..0000000000 --- a/core/src/test/scala/spark/ShuffleNettySuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.BeforeAndAfterAll - - -class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. - - override def beforeAll(configMap: Map[String, Any]) { - System.setProperty("spark.shuffle.use.netty", "true") - } - - override def afterAll(configMap: Map[String, Any]) { - System.setProperty("spark.shuffle.use.netty", "false") - } -} diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala deleted file mode 100644 index 8745689c70..0000000000 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -import spark.SparkContext._ -import spark.ShuffleSuite.NonJavaSerializableClass -import spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} -import spark.util.MutablePair - - -class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - test("groupByKey without compression") { - try { - System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) - val groups = pairs.groupByKey(4).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } finally { - System.setProperty("spark.shuffle.compress", "true") - } - } - - test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test") - val NUM_BLOCKS = 3 - - val a = sc.parallelize(1 to 10, 2) - val b = a.map { x => - (x, new NonJavaSerializableClass(x * 2)) - } - // If the Kryo serializer is not used correctly, the shuffle would fail because the - // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[spark.KryoSerializer].getName) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - - assert(c.count === 10) - - // All blocks must have non-zero size - (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) - } - } - - test("shuffle serializer") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - val a = sc.parallelize(1 to 10, 2) - val b = a.map { x => - (x, new NonJavaSerializableClass(x * 2)) - } - // If the Kryo serializer is not used correctly, the shuffle would fail because the - // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(classOf[spark.KryoSerializer].getName) - assert(c.count === 10) - } - - test("zero sized blocks") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 - val a = sc.parallelize(1 to 4, NUM_BLOCKS) - val b = a.map(x => (x, x*2)) - - // NOTE: The default Java serializer doesn't create zero-sized blocks. - // So, use Kryo - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - .setSerializer(classOf[spark.KryoSerializer].getName) - - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - assert(c.count === 4) - - val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) - } - val nonEmptyBlocks = blockSizes.filter(x => x > 0) - - // We should have at most 4 non-zero sized partitions - assert(nonEmptyBlocks.size <= 4) - } - - test("zero sized blocks without kryo") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 - val a = sc.parallelize(1 to 4, NUM_BLOCKS) - val b = a.map(x => (x, x*2)) - - // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - assert(c.count === 4) - - val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) - } - val nonEmptyBlocks = blockSizes.filter(x => x > 0) - - // We should have at most 4 non-zero sized partitions - assert(nonEmptyBlocks.size <= 4) - } - - test("shuffle using mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) - val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) - .collect() - - data.foreach { pair => results should contain (pair) } - } - - test("sorting using mutable pairs") { - // This is not in SortingSuite because of the local cluster setup. - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) - val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) - .sortByKey().collect() - results(0) should be (p(1, 11)) - results(1) should be (p(2, 22)) - results(2) should be (p(3, 33)) - results(3) should be (p(100, 100)) - } - - test("cogroup using mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) - val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) - val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) - val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() - - assert(results(1)(0).length === 3) - assert(results(1)(0).contains(1)) - assert(results(1)(0).contains(2)) - assert(results(1)(0).contains(3)) - assert(results(1)(1).length === 2) - assert(results(1)(1).contains("11")) - assert(results(1)(1).contains("12")) - assert(results(2)(0).length === 1) - assert(results(2)(0).contains(1)) - assert(results(2)(1).length === 1) - assert(results(2)(1).contains("22")) - assert(results(3)(0).length === 0) - assert(results(3)(1).contains("3")) - } - - test("subtract mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) - val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) - val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) - val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() - results should have length (1) - // substracted rdd return results as Tuple2 - results(0) should be ((3, 33)) - } -} - -object ShuffleSuite { - - def mergeCombineException(x: Int, y: Int): Int = { - throw new SparkException("Exception for map-side combine.") - x + y - } - - class NonJavaSerializableClass(val value: Int) -} diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala deleted file mode 100644 index 1ef812dfbd..0000000000 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfterAll -import org.scalatest.PrivateMethodTester - -class DummyClass1 {} - -class DummyClass2 { - val x: Int = 0 -} - -class DummyClass3 { - val x: Int = 0 - val y: Double = 0.0 -} - -class DummyClass4(val d: DummyClass3) { - val x: Int = 0 -} - -object DummyString { - def apply(str: String) : DummyString = new DummyString(str.toArray) -} -class DummyString(val arr: Array[Char]) { - override val hashCode: Int = 0 - // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f - @transient val hash32: Int = 0 -} - -class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { - - var oldArch: String = _ - var oldOops: String = _ - - override def beforeAll() { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - } - - override def afterAll() { - resetOrClear("os.arch", oldArch) - resetOrClear("spark.test.useCompressedOops", oldOops) - } - - test("simple classes") { - assert(SizeEstimator.estimate(new DummyClass1) === 16) - assert(SizeEstimator.estimate(new DummyClass2) === 16) - assert(SizeEstimator.estimate(new DummyClass3) === 24) - assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) - assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("strings") { - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - } - - test("primitive arrays") { - assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) - assert(SizeEstimator.estimate(new Array[Char](10)) === 40) - assert(SizeEstimator.estimate(new Array[Short](10)) === 40) - assert(SizeEstimator.estimate(new Array[Int](10)) === 56) - assert(SizeEstimator.estimate(new Array[Long](10)) === 96) - assert(SizeEstimator.estimate(new Array[Float](10)) === 56) - assert(SizeEstimator.estimate(new Array[Double](10)) === 96) - assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) - assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) - } - - test("object arrays") { - // Arrays containing nulls should just have one pointer per element - assert(SizeEstimator.estimate(new Array[String](10)) === 56) - assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) - - // For object arrays with non-null elements, each object should take one pointer plus - // however many bytes that class takes. (Note that Array.fill calls the code in its - // second parameter separately for each object, so we get distinct objects.) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) - assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) - - // Past size 100, our samples 100 elements, but we should still get the right size. - assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) - - // If an array contains the *same* element many times, we should only count it once. - val d1 = new DummyClass1 - assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object - assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object - - // Same thing with huge array containing the same element many times. Note that this won't - // return exactly 4032 because it can't tell that *all* the elements will equal the first - // one it samples, but it should be close to that. - - // TODO: If we sample 100 elements, this should always be 4176 ? - val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) - assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") - assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") - } - - test("32-bit arch") { - val arch = System.setProperty("os.arch", "x86") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - - resetOrClear("os.arch", arch) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("64-bit arch with no compressed oops") { - val arch = System.setProperty("os.arch", "amd64") - val oops = System.setProperty("spark.test.useCompressedOops", "false") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 56) - assert(SizeEstimator.estimate(DummyString("a")) === 64) - assert(SizeEstimator.estimate(DummyString("ab")) === 64) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) - - resetOrClear("os.arch", arch) - resetOrClear("spark.test.useCompressedOops", oops) - } - - def resetOrClear(prop: String, oldValue: String) { - if (oldValue != null) { - System.setProperty(prop, oldValue) - } else { - System.clearProperty(prop) - } - } -} diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala deleted file mode 100644 index b933c4aab8..0000000000 --- a/core/src/test/scala/spark/SortingSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.matchers.ShouldMatchers -import SparkContext._ - -class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { - - test("sortByKey") { - val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) - } - - test("large array") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey() - assert(sorted.partitions.size === 2) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 1) - assert(sorted.partitions.size === 1) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 20) - assert(sorted.partitions.size === 20) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("sort descending") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 1) - assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("more partitions than elements") { - val rand = new scala.util.Random() - val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("empty RDD") { - val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("partition balancing") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey() - assert(sorted.collect() === pairArr.sortBy(_._1)) - val partitions = sorted.collectPartitions() - logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be < partitions(1).head - partitions(1).last should be < partitions(2).head - partitions(2).last should be < partitions(3).head - } - - test("partition balancing for descending sort") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey(false) - assert(sorted.collect() === pairArr.sortBy(_._1).reverse) - val partitions = sorted.collectPartitions() - logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be > partitions(1).head - partitions(1).last should be > partitions(2).head - partitions(2).last should be > partitions(3).head - } -} - diff --git a/core/src/test/scala/spark/SparkContextInfoSuite.scala b/core/src/test/scala/spark/SparkContextInfoSuite.scala deleted file mode 100644 index 6d50bf5e1b..0000000000 --- a/core/src/test/scala/spark/SparkContextInfoSuite.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import spark.SparkContext._ - -class SparkContextInfoSuite extends FunSuite with LocalSparkContext { - test("getPersistentRDDs only returns RDDs that are marked as cached") { - sc = new SparkContext("local", "test") - assert(sc.getPersistentRDDs.isEmpty === true) - - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(sc.getPersistentRDDs.isEmpty === true) - - rdd.cache() - assert(sc.getPersistentRDDs.size === 1) - assert(sc.getPersistentRDDs.values.head === rdd) - } - - test("getPersistentRDDs returns an immutable map") { - sc = new SparkContext("local", "test") - val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - - val myRdds = sc.getPersistentRDDs - assert(myRdds.size === 1) - assert(myRdds.values.head === rdd1) - - val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() - - // getPersistentRDDs should have 2 RDDs, but myRdds should not change - assert(sc.getPersistentRDDs.size === 2) - assert(myRdds.size === 1) - } - - test("getRDDStorageInfo only reports on RDDs that actually persist data") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - - assert(sc.getRDDStorageInfo.size === 0) - - rdd.collect() - assert(sc.getRDDStorageInfo.size === 1) - } -} \ No newline at end of file diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala deleted file mode 100644 index f2acd0bd3c..0000000000 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import SparkContext._ - -/** - * Holds state shared across task threads in some ThreadingSuite tests. - */ -object ThreadingSuiteState { - val runningThreads = new AtomicInteger - val failed = new AtomicBoolean - - def clear() { - runningThreads.set(0) - failed.set(false) - } -} - -class ThreadingSuite extends FunSuite with LocalSparkContext { - - test("accessing SparkContext form a different thread") { - sc = new SparkContext("local", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var answer1: Int = 0 - @volatile var answer2: Int = 0 - new Thread { - override def run() { - answer1 = nums.reduce(_ + _) - answer2 = nums.first() // This will run "locally" in the current thread - sem.release() - } - }.start() - sem.acquire() - assert(answer1 === 55) - assert(answer2 === 1) - } - - test("accessing SparkContext form multiple threads") { - sc = new SparkContext("local", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var ok = true - for (i <- 0 until 10) { - new Thread { - override def run() { - val answer1 = nums.reduce(_ + _) - if (answer1 != 55) { - printf("In thread %d: answer1 was %d\n", i, answer1) - ok = false - } - val answer2 = nums.first() // This will run "locally" in the current thread - if (answer2 != 1) { - printf("In thread %d: answer2 was %d\n", i, answer2) - ok = false - } - sem.release() - } - }.start() - } - sem.acquire(10) - if (!ok) { - fail("One or more threads got the wrong answer from an RDD operation") - } - } - - test("accessing multi-threaded SparkContext form multiple threads") { - sc = new SparkContext("local[4]", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var ok = true - for (i <- 0 until 10) { - new Thread { - override def run() { - val answer1 = nums.reduce(_ + _) - if (answer1 != 55) { - printf("In thread %d: answer1 was %d\n", i, answer1) - ok = false - } - val answer2 = nums.first() // This will run "locally" in the current thread - if (answer2 != 1) { - printf("In thread %d: answer2 was %d\n", i, answer2) - ok = false - } - sem.release() - } - }.start() - } - sem.acquire(10) - if (!ok) { - fail("One or more threads got the wrong answer from an RDD operation") - } - } - - test("parallel job execution") { - // This test launches two jobs with two threads each on a 4-core local cluster. Each thread - // waits until there are 4 threads running at once, to test that both jobs have been launched. - sc = new SparkContext("local[4]", "test") - val nums = sc.parallelize(1 to 2, 2) - val sem = new Semaphore(0) - ThreadingSuiteState.clear() - for (i <- 0 until 2) { - new Thread { - override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() - } - }.start() - } - sem.acquire(2) - if (ThreadingSuiteState.failed.get()) { - fail("One or more threads didn't see runningThreads = 4") - } - } -} diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala deleted file mode 100644 index 93977d16f4..0000000000 --- a/core/src/test/scala/spark/UnpersistSuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Span, Millis} -import spark.SparkContext._ - -class UnpersistSuite extends FunSuite with LocalSparkContext { - test("unpersist RDD") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - rdd.count - assert(sc.persistentRdds.isEmpty === false) - rdd.unpersist() - assert(sc.persistentRdds.isEmpty === true) - - failAfter(Span(3000, Millis)) { - try { - while (! sc.getRDDStorageInfo.isEmpty) { - Thread.sleep(200) - } - } catch { - case _ => { Thread.sleep(10) } - // Do nothing. We might see exceptions because block manager - // is racing this thread to remove entries from the driver. - } - } - assert(sc.getRDDStorageInfo.isEmpty === true) - } -} diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala deleted file mode 100644 index 98a6c1a1c9..0000000000 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import com.google.common.base.Charsets -import com.google.common.io.Files -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} -import org.scalatest.FunSuite -import org.apache.commons.io.FileUtils -import scala.util.Random - -class UtilsSuite extends FunSuite { - - test("bytesToString") { - assert(Utils.bytesToString(10) === "10.0 B") - assert(Utils.bytesToString(1500) === "1500.0 B") - assert(Utils.bytesToString(2000000) === "1953.1 KB") - assert(Utils.bytesToString(2097152) === "2.0 MB") - assert(Utils.bytesToString(2306867) === "2.2 MB") - assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") - } - - test("copyStream") { - //input array initialization - val bytes = Array.ofDim[Byte](9000) - Random.nextBytes(bytes) - - val os = new ByteArrayOutputStream() - Utils.copyStream(new ByteArrayInputStream(bytes), os) - - assert(os.toByteArray.toList.equals(bytes.toList)) - } - - test("memoryStringToMb") { - assert(Utils.memoryStringToMb("1") === 0) - assert(Utils.memoryStringToMb("1048575") === 0) - assert(Utils.memoryStringToMb("3145728") === 3) - - assert(Utils.memoryStringToMb("1024k") === 1) - assert(Utils.memoryStringToMb("5000k") === 4) - assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) - - assert(Utils.memoryStringToMb("1024m") === 1024) - assert(Utils.memoryStringToMb("5000m") === 5000) - assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) - - assert(Utils.memoryStringToMb("2g") === 2048) - assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) - - assert(Utils.memoryStringToMb("2t") === 2097152) - assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) - } - - test("splitCommandString") { - assert(Utils.splitCommandString("") === Seq()) - assert(Utils.splitCommandString("a") === Seq("a")) - assert(Utils.splitCommandString("aaa") === Seq("aaa")) - assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) - assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) - assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) - assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("'b c'") === Seq("b c")) - assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) - assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) - assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) - assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) - assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) - assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) - assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) - assert(Utils.splitCommandString("'a'b") === Seq("ab")) - assert(Utils.splitCommandString("'a''b'") === Seq("ab")) - assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) - assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) - assert(Utils.splitCommandString("''") === Seq("")) - assert(Utils.splitCommandString("\"\"") === Seq("")) - } - - test("string formatting of time durations") { - val second = 1000 - val minute = second * 60 - val hour = minute * 60 - def str = Utils.msDurationToString(_) - - assert(str(123) === "123 ms") - assert(str(second) === "1.0 s") - assert(str(second + 462) === "1.5 s") - assert(str(hour) === "1.00 h") - assert(str(minute) === "1.0 m") - assert(str(minute + 4 * second + 34) === "1.1 m") - assert(str(10 * hour + minute + 4 * second) === "10.02 h") - assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") - } - - test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() - val f1Path = tmpDir2 + "/f1" - val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) - f1.close() - - // Read first few bytes - assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") - - // Read some middle bytes - assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") - - // Read last few bytes - assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") - - // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") - - // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") - - // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") - - FileUtils.deleteDirectory(tmpDir2) - } -} - diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala deleted file mode 100644 index bb5d379273..0000000000 --- a/core/src/test/scala/spark/ZippedPartitionsSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.immutable.NumericRange - -import org.scalatest.FunSuite -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -import SparkContext._ - - -object ZippedPartitionsSuite { - def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { - Iterator(i.toArray.size, s.toArray.size, d.toArray.size) - } -} - -class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { - test("print sizes") { - val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) - val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) - val data3 = sc.makeRDD(Array(1.0, 2.0), 2) - - val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData) - - val obtainedSizes = zippedRDD.collect() - val expectedSizes = Array(2, 3, 1, 2, 3, 1) - assert(obtainedSizes.size == 6) - assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) - } -} diff --git a/core/src/test/scala/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/spark/io/CompressionCodecSuite.scala deleted file mode 100644 index 1ba82fe2b9..0000000000 --- a/core/src/test/scala/spark/io/CompressionCodecSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.io - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import org.scalatest.FunSuite - - -class CompressionCodecSuite extends FunSuite { - - def testCodec(codec: CompressionCodec) { - // Write 1000 integers to the output stream, compressed. - val outputStream = new ByteArrayOutputStream() - val out = codec.compressedOutputStream(outputStream) - for (i <- 1 until 1000) { - out.write(i % 256) - } - out.close() - - // Read the 1000 integers back. - val inputStream = new ByteArrayInputStream(outputStream.toByteArray) - val in = codec.compressedInputStream(inputStream) - for (i <- 1 until 1000) { - assert(in.read() === i % 256) - } - in.close() - } - - test("default compression codec") { - val codec = CompressionCodec.createCodec() - assert(codec.getClass === classOf[SnappyCompressionCodec]) - testCodec(codec) - } - - test("lzf compression codec") { - val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) - assert(codec.getClass === classOf[LZFCompressionCodec]) - testCodec(codec) - } - - test("snappy compression codec") { - val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) - assert(codec.getClass === classOf[SnappyCompressionCodec]) - testCodec(codec) - } -} diff --git a/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala deleted file mode 100644 index b0213b62d9..0000000000 --- a/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics - -import org.scalatest.{BeforeAndAfter, FunSuite} - -class MetricsConfigSuite extends FunSuite with BeforeAndAfter { - var filePath: String = _ - - before { - filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile() - } - - test("MetricsConfig with default properties") { - val conf = new MetricsConfig(Option("dummy-file")) - conf.initialize() - - assert(conf.properties.size() === 5) - assert(conf.properties.getProperty("test-for-dummy") === null) - - val property = conf.getInstance("random") - assert(property.size() === 3) - assert(property.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(property.getProperty("sink.servlet.uri") === "/metrics/json") - assert(property.getProperty("sink.servlet.sample") === "false") - } - - test("MetricsConfig with properties set") { - val conf = new MetricsConfig(Option(filePath)) - conf.initialize() - - val masterProp = conf.getInstance("master") - assert(masterProp.size() === 6) - assert(masterProp.getProperty("sink.console.period") === "20") - assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(masterProp.getProperty("sink.servlet.uri") === "/metrics/master/json") - assert(masterProp.getProperty("sink.servlet.sample") === "false") - - val workerProp = conf.getInstance("worker") - assert(workerProp.size() === 6) - assert(workerProp.getProperty("sink.console.period") === "10") - assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(workerProp.getProperty("sink.servlet.uri") === "/metrics/json") - assert(workerProp.getProperty("sink.servlet.sample") === "false") - } - - test("MetricsConfig with subProperties") { - val conf = new MetricsConfig(Option(filePath)) - conf.initialize() - - val propCategories = conf.propertyCategories - assert(propCategories.size === 3) - - val masterProp = conf.getInstance("master") - val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX) - assert(sourceProps.size === 1) - assert(sourceProps("jvm").getProperty("class") === "spark.metrics.source.JvmSource") - - val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX) - assert(sinkProps.size === 2) - assert(sinkProps.contains("console")) - assert(sinkProps.contains("servlet")) - - val consoleProps = sinkProps("console") - assert(consoleProps.size() === 2) - - val servletProps = sinkProps("servlet") - assert(servletProps.size() === 3) - } -} diff --git a/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala deleted file mode 100644 index dc65ac6994..0000000000 --- a/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics - -import org.scalatest.{BeforeAndAfter, FunSuite} - -class MetricsSystemSuite extends FunSuite with BeforeAndAfter { - var filePath: String = _ - - before { - filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() - System.setProperty("spark.metrics.conf", filePath) - } - - test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default") - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks - - assert(sources.length === 0) - assert(sinks.length === 0) - assert(!metricsSystem.getServletHandlers.isEmpty) - } - - test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test") - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks - - assert(sources.length === 0) - assert(sinks.length === 1) - assert(!metricsSystem.getServletHandlers.isEmpty) - - val source = new spark.deploy.master.MasterSource(null) - metricsSystem.registerSource(source) - assert(sources.length === 1) - } -} diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala deleted file mode 100644 index dc8ca941c1..0000000000 --- a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.{ BeforeAndAfter, FunSuite } -import spark.SparkContext._ -import spark.rdd.JdbcRDD -import java.sql._ - -class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") - val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") - try { - val create = conn.createStatement - create.execute(""" - CREATE TABLE FOO( - ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), - DATA INTEGER - )""") - create.close - val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") - (1 to 100).foreach { i => - insert.setInt(1, i * 2) - insert.executeUpdate - } - insert.close - } catch { - case e: SQLException if e.getSQLState == "X0Y32" => - // table exists - } finally { - conn.close - } - } - - test("basic functionality") { - sc = new SparkContext("local", "test") - val rdd = new JdbcRDD( - sc, - () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, - "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", - 1, 100, 3, - (r: ResultSet) => { r.getInt(1) } ).cache - - assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) - } - - after { - try { - DriverManager.getConnection("jdbc:derby:;shutdown=true") - } catch { - case se: SQLException if se.getSQLState == "XJ015" => - // normal shutdown - } - } -} diff --git a/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala deleted file mode 100644 index d1276d541f..0000000000 --- a/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import scala.collection.immutable.NumericRange - -import org.scalatest.FunSuite -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -class ParallelCollectionSplitSuite extends FunSuite with Checkers { - test("one element per slice") { - val data = Array(1, 2, 3) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1") - assert(slices(1).mkString(",") === "2") - assert(slices(2).mkString(",") === "3") - } - - test("one slice") { - val data = Array(1, 2, 3) - val slices = ParallelCollectionRDD.slice(data, 1) - assert(slices.size === 1) - assert(slices(0).mkString(",") === "1,2,3") - } - - test("equal slices") { - val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1,2,3") - assert(slices(1).mkString(",") === "4,5,6") - assert(slices(2).mkString(",") === "7,8,9") - } - - test("non-equal slices") { - val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1,2,3") - assert(slices(1).mkString(",") === "4,5,6") - assert(slices(2).mkString(",") === "7,8,9,10") - } - - test("splitting exclusive range") { - val data = 0 until 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === (0 to 32).mkString(",")) - assert(slices(1).mkString(",") === (33 to 65).mkString(",")) - assert(slices(2).mkString(",") === (66 to 99).mkString(",")) - } - - test("splitting inclusive range") { - val data = 0 to 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === (0 to 32).mkString(",")) - assert(slices(1).mkString(",") === (33 to 66).mkString(",")) - assert(slices(2).mkString(",") === (67 to 100).mkString(",")) - } - - test("empty data") { - val data = new Array[Int](0) - val slices = ParallelCollectionRDD.slice(data, 5) - assert(slices.size === 5) - for (slice <- slices) assert(slice.size === 0) - } - - test("zero slices") { - val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } - } - - test("negative number of slices") { - val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } - } - - test("exclusive ranges sliced into ranges") { - val data = 1 until 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[Range])) - } - - test("inclusive ranges sliced into ranges") { - val data = 1 to 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[Range])) - } - - test("large ranges don't overflow") { - val N = 100 * 1000 * 1000 - val data = 0 until N - val slices = ParallelCollectionRDD.slice(data, 40) - assert(slices.size === 40) - for (i <- 0 until 40) { - assert(slices(i).isInstanceOf[Range]) - val range = slices(i).asInstanceOf[Range] - assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") - assert(range.step === 1, "slice " + i + " step") - } - } - - test("random array tests") { - val gen = for { - d <- arbitrary[List[Int]] - n <- Gen.choose(1, 100) - } yield (d, n) - val prop = forAll(gen) { - (tuple: (List[Int], Int)) => - val d = tuple._1 - val n = tuple._2 - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("random exclusive range tests") { - val gen = for { - a <- Gen.choose(-100, 100) - b <- Gen.choose(-100, 100) - step <- Gen.choose(-5, 5) suchThat (_ != 0) - n <- Gen.choose(1, 100) - } yield (a until b by step, n) - val prop = forAll(gen) { - case (d: Range, n: Int) => - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("random inclusive range tests") { - val gen = for { - a <- Gen.choose(-100, 100) - b <- Gen.choose(-100, 100) - step <- Gen.choose(-5, 5) suchThat (_ != 0) - n <- Gen.choose(1, 100) - } yield (a to b by step, n) - val prop = forAll(gen) { - case (d: Range, n: Int) => - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("exclusive ranges of longs") { - val data = 1L until 100L - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("inclusive ranges of longs") { - val data = 1L to 100L - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("exclusive ranges of doubles") { - val data = 1.0 until 100.0 by 1.0 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("inclusive ranges of doubles") { - val data = 1.0 to 100.0 by 1.0 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } -} diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala deleted file mode 100644 index 3b4a0d52fc..0000000000 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import scala.collection.mutable.{Map, HashMap} - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark.LocalSparkContext -import spark.MapOutputTracker -import spark.RDD -import spark.SparkContext -import spark.Partition -import spark.TaskContext -import spark.{Dependency, ShuffleDependency, OneToOneDependency} -import spark.{FetchFailed, Success, TaskEndReason} -import spark.storage.{BlockManagerId, BlockManagerMaster} - -import spark.scheduler.cluster.Pool -import spark.scheduler.cluster.SchedulingMode -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** - * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler - * rather than spawning an event loop thread as happens in the real code. They use EasyMock - * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are - * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead - * host notifications are sent). In addition, tests may check for side effects on a non-mocked - * MapOutputTracker instance. - * - * Tests primarily consist of running DAGScheduler#processEvent and - * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet) - * and capturing the resulting TaskSets from the mock TaskScheduler. - */ -class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - /** Set of TaskSets the DAGScheduler has requested executed. */ - val taskSets = scala.collection.mutable.Buffer[TaskSet]() - val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { - // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) - taskSets += taskSet - } - override def setListener(listener: TaskSchedulerListener) = {} - override def defaultParallelism() = 2 - } - - var mapOutputTracker: MapOutputTracker = null - var scheduler: DAGScheduler = null - - /** - * Set of cache locations to return from our mock BlockManagerMaster. - * Keys are (rdd ID, partition ID). Anything not present will return an empty - * list of cache locations silently. - */ - val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] - // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null) { - override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map { name => - val pieces = name.split("_") - if (pieces(0) == "rdd") { - val key = pieces(1).toInt -> pieces(2).toInt - cacheLocations.getOrElse(key, Seq()) - } else { - Seq() - } - }.toSeq - } - override def removeExecutor(execId: String) { - // don't need to propagate to the driver, which we don't have - } - } - - /** The list of results that DAGScheduler has collected. */ - val results = new HashMap[Int, Any]() - var failure: Exception = _ - val listener = new JobListener() { - override def taskSucceeded(index: Int, result: Any) = results.put(index, result) - override def jobFailed(exception: Exception) = { failure = exception } - } - - before { - sc = new SparkContext("local", "DAGSchedulerSuite") - taskSets.clear() - cacheLocations.clear() - results.clear() - mapOutputTracker = new MapOutputTracker() - scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } - } - - after { - scheduler.stop() - } - - /** - * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - * This is a pair RDD type so it can always be used in ShuffleDependencies. - */ - type MyRDD = RDD[(Int, Int)] - - /** - * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and - * preferredLocations (if any) that are passed to them. They are deliberately not executable - * so we can test that DAGScheduler does not try to execute RDDs locally. - */ - private def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]], - locations: Seq[Seq[String]] = Nil - ): MyRDD = { - val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new RuntimeException("should not be reached") - override def getPartitions = (0 to maxPartition).map(i => new Partition { - override def index = i - }).toArray - override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil - override def toString: String = "DAGSchedulerSuiteRDD " + id - } - } - - /** - * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting - * the scheduler not to exit. - * - * After processing the event, submit waiting stages as is done on most iterations of the - * DAGScheduler event loop. - */ - private def runEvent(event: DAGSchedulerEvent) { - assert(!scheduler.processEvent(event)) - scheduler.submitWaitingStages() - } - - /** - * When we submit dummy Jobs, this is the compute function we supply. Except in a local test - * below, we do not expect this function to ever be executed; instead, we will return results - * directly through CompletionEvents. - */ - private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => - it.next.asInstanceOf[Tuple2[_, _]]._1 - - /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ - private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { - assert(taskSet.tasks.size >= results.size) - for ((result, i) <- results.zipWithIndex) { - if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null)) - } - } - } - - /** Sends the rdd to the scheduler for scheduling. */ - private def submit( - rdd: RDD[_], - partitions: Array[Int], - func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, - listener: JobListener = listener) { - runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) - } - - /** Sends TaskSetFailed to the scheduler. */ - private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) - } - - test("zero split job") { - val rdd = makeRdd(0, Nil) - var numResults = 0 - val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception - } - submit(rdd, Array(), listener = fakeListener) - assert(numResults === 0) - } - - test("run trivial job") { - val rdd = makeRdd(1, Nil) - submit(rdd, Array(0)) - complete(taskSets(0), List((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("local job") { - val rdd = new MyRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) - assert(results === Map(0 -> 42)) - } - - test("run trivial job w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - submit(finalRdd, Array(0)) - complete(taskSets(0), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("cache location preferences w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - cacheLocations(baseRdd.id -> 0) = - Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) - submit(finalRdd, Array(0)) - val taskSet = taskSets(0) - assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) - complete(taskSet, Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("trivial job failure") { - submit(makeRdd(1, Nil), Array(0)) - failed(taskSets(0), "some failure") - assert(failure.getMessage === "Job failed: some failure") - } - - test("run trivial shuffle") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) - submit(reduceRdd, Array(0)) - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) - complete(taskSets(1), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("run trivial shuffle with fetch failure") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) - submit(reduceRdd, Array(0, 1)) - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // the 2nd ResultTask failed - complete(taskSets(1), Seq( - (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) - // this will get called - // blockManagerMaster.removeExecutor("exec-hostA") - // ask the scheduler to try it again - scheduler.resubmitFailedStages() - // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) - // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - } - - test("ignore late map task completions") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) - submit(reduceRdd, Array(0, 1)) - // pretend we were told hostA went away - val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) - val newEpoch = mapOutputTracker.getEpoch - assert(newEpoch > oldEpoch) - val noAccum = Map[Long, Any]() - val taskSet = taskSets(0) - // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null)) - // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - // should work because it's a new epoch - taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) - complete(taskSets(1), Seq((Success, 42), (Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - } - - test("run trivial shuffle with out-of-band failure and retry") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) - submit(reduceRdd, Array(0)) - // blockManagerMaster.removeExecutor("exec-hostA") - // pretend we were told hostA went away - runEvent(ExecutorLost("exec-hostA")) - // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks - // rather than marking it is as failed and waiting. - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("recursive shuffle failures") { - val shuffleOneRdd = makeRdd(2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) - submit(finalRdd, Array(0)) - // have the first stage complete normally - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) - // have the second stage complete normally - complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)))) - // fail the third stage because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) - // TODO assert this: - // blockManagerMaster.removeExecutor("exec-hostA") - // have DAGScheduler try again - scheduler.resubmitFailedStages() - complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) - complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) - complete(taskSets(5), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("cached post-shuffle") { - val shuffleOneRdd = makeRdd(2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) - submit(finalRdd, Array(0)) - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - // complete stage 2 - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) - // complete stage 1 - complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // pretend stage 0 failed because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) - // TODO assert this: - // blockManagerMaster.removeExecutor("exec-hostA") - // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. - scheduler.resubmitFailedStages() - assertLocations(taskSets(3), Seq(Seq("hostD"))) - // allow hostD to recover - complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) - complete(taskSets(4), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - /** - * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. - * Note that this checks only the host and not the executor ID. - */ - private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { - assert(hosts.size === taskSet.tasks.size) - for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { - assert(taskLocs.map(_.host) === expectedLocs) - } - } - - private def makeMapStatus(host: String, reduces: Int): MapStatus = - new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) - - private def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345, 0) - -} diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala deleted file mode 100644 index bb9e715f95..0000000000 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties -import java.util.concurrent.LinkedBlockingQueue -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import scala.collection.mutable -import spark._ -import spark.SparkContext._ - - -class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { - - test("inner method") { - sc = new SparkContext("local", "joblogger") - val joblogger = new JobLogger { - def createLogWriterTest(jobID: Int) = createLogWriter(jobID) - def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) - def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) - def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) - } - type MyRDD = RDD[(Int, Int)] - def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]] - ): MyRDD = { - val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new RuntimeException("should not be reached") - override def getPartitions = (0 to maxPartition).map(i => new Partition { - override def index = i - }).toArray - } - } - val jobID = 5 - val parentRdd = makeRdd(4, Nil) - val shuffleDep = new ShuffleDependency(parentRdd, null) - val rootRdd = makeRdd(4, List(shuffleDep)) - val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None) - val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None) - - joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null)) - joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) - parentRdd.setName("MyRDD") - joblogger.getRddNameTest(parentRdd) should be ("MyRDD") - joblogger.createLogWriterTest(jobID) - joblogger.getJobIDtoPrintWriter.size should be (1) - joblogger.buildJobDepTest(jobID, rootStage) - joblogger.getJobIDToStages.get(jobID).get.size should be (2) - joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) - joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) - joblogger.closeLogWriterTest(jobID) - joblogger.getStageIDToJobID.size should be (0) - joblogger.getJobIDToStages.size should be (0) - joblogger.getJobIDtoPrintWriter.size should be (0) - } - - test("inner variables") { - sc = new SparkContext("local[4]", "joblogger") - val joblogger = new JobLogger { - override protected def closeLogWriter(jobID: Int) = - getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => - fileWriter.close() - } - } - sc.addSparkListener(joblogger) - val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } - rdd.reduceByKey(_+_).collect() - - joblogger.getLogDir should be ("/tmp/spark") - joblogger.getJobIDtoPrintWriter.size should be (1) - joblogger.getStageIDToJobID.size should be (2) - joblogger.getStageIDToJobID.get(0) should be (Some(0)) - joblogger.getStageIDToJobID.get(1) should be (Some(0)) - joblogger.getJobIDToStages.size should be (1) - } - - - test("interface functions") { - sc = new SparkContext("local[4]", "joblogger") - val joblogger = new JobLogger { - var onTaskEndCount = 0 - var onJobEndCount = 0 - var onJobStartCount = 0 - var onStageCompletedCount = 0 - var onStageSubmittedCount = 0 - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 - override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 - override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 - override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 - } - sc.addSparkListener(joblogger) - val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } - rdd.reduceByKey(_+_).collect() - - joblogger.onJobStartCount should be (1) - joblogger.onJobEndCount should be (1) - joblogger.onTaskEndCount should be (8) - joblogger.onStageSubmittedCount should be (2) - joblogger.onStageCompletedCount should be (2) - } -} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala deleted file mode 100644 index 392d67d67b..0000000000 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import org.scalatest.FunSuite -import spark.{SparkContext, LocalSparkContext} -import scala.collection.mutable -import org.scalatest.matchers.ShouldMatchers -import spark.SparkContext._ - -/** - * - */ - -class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { - - test("local metrics") { - sc = new SparkContext("local[4]", "test") - val listener = new SaveStageInfo - sc.addSparkListener(listener) - sc.addSparkListener(new StatsReportListener) - //just to make sure some of the tasks take a noticeable amount of time - val w = {i:Int => - if (i == 0) - Thread.sleep(100) - i - } - - val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} - d.count - listener.stageInfos.size should be (1) - - val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1") - - val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2") - - val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)} - d4.setName("A Cogroup") - - d4.collectAsMap - - listener.stageInfos.size should be (4) - listener.stageInfos.foreach {stageInfo => - //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms - checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") - if (stageInfo.stage.rdd.name == d4.name) { - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") - } - - stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) - if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { - taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) - } - if (stageInfo.stage.rdd.name == d4.name) { - taskMetrics.shuffleReadMetrics should be ('defined) - val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be > (0) - sm.localBlocksFetched should be > (0) - sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) - sm.remoteFetchTime should be (0l) - } - } - } - } - - def checkNonZeroAvg(m: Traversable[Long], msg: String) { - assert(m.sum / m.size.toDouble > 0.0, msg) - } - - def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = { - val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name} - !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty - } - - class SaveStageInfo extends SparkListener { - val stageInfos = mutable.Buffer[StageInfo]() - override def onStageCompleted(stage: StageCompleted) { - stageInfos += stage.stageInfo - } - } - -} diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala deleted file mode 100644 index 95a6eee2fc..0000000000 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import spark.TaskContext -import spark.RDD -import spark.SparkContext -import spark.Partition -import spark.LocalSparkContext - -class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - test("Calls executeOnCompleteCallbacks after failure") { - var completed = false - sc = new SparkContext("local", "test") - val rdd = new RDD[String](sc, List()) { - override def getPartitions = Array[Partition](StubPartition(0)) - override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) - sys.error("failed") - } - } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) - intercept[RuntimeException] { - task.run(0) - } - assert(completed === true) - } - - case class StubPartition(val index: Int) extends Partition -} diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala deleted file mode 100644 index abfdabf5fe..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark._ -import spark.scheduler._ -import spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer - -import java.util.Properties - -class FakeTaskSetManager( - initPriority: Int, - initStageId: Int, - initNumTasks: Int, - clusterScheduler: ClusterScheduler, - taskSet: TaskSet) - extends ClusterTaskSetManager(clusterScheduler, taskSet) { - - parent = null - weight = 1 - minShare = 2 - runningTasks = 0 - priority = initPriority - stageId = initStageId - name = "TaskSet_"+stageId - override val numTasks = initNumTasks - tasksFinished = 0 - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable) { - } - - override def removeSchedulable(schedulable: Schedulable) { - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksFinished + runningTasks < numTasks) { - increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - return true - } - - def taskFinished() { - decreaseRunningTasks(1) - tasksFinished +=1 - if (tasksFinished == numTasks) { - parent.removeSchedulable(this) - } - } - - def abort() { - decreaseRunningTasks(runningTasks) - parent.removeSchedulable(this) - } -} - -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { - new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) - } - - def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedTaskSetQueue() - /* Just for Test*/ - for (manager <- taskSetQueue) { - logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) - } - for (taskSet <- taskSetQueue) { - taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { - case Some(task) => - return taskSet.stageId - case None => {} - } - } - -1 - } - - def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { - assert(resourceOffer(rootPool) === expectedTaskSetId) - } - - test("FIFO Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) - val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager0, null) - schedulableBuilder.addTaskSetManager(taskSetManager1, null) - schedulableBuilder.addTaskSetManager(taskSetManager2, null) - - checkTaskSetId(rootPool, 0) - resourceOffer(rootPool) - checkTaskSetId(rootPool, 1) - resourceOffer(rootPool) - taskSetManager1.abort() - checkTaskSetId(rootPool, 2) - } - - test("Fair Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.fairscheduler.allocation.file", xmlPath) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val schedulableBuilder = new FairSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - assert(rootPool.getSchedulableByName("default") != null) - assert(rootPool.getSchedulableByName("1") != null) - assert(rootPool.getSchedulableByName("2") != null) - assert(rootPool.getSchedulableByName("3") != null) - assert(rootPool.getSchedulableByName("1").minShare === 2) - assert(rootPool.getSchedulableByName("1").weight === 1) - assert(rootPool.getSchedulableByName("2").minShare === 3) - assert(rootPool.getSchedulableByName("2").weight === 1) - assert(rootPool.getSchedulableByName("3").minShare === 2) - assert(rootPool.getSchedulableByName("3").weight === 1) - - val properties1 = new Properties() - properties1.setProperty("spark.scheduler.cluster.fair.pool","1") - val properties2 = new Properties() - properties2.setProperty("spark.scheduler.cluster.fair.pool","2") - - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) - schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 1) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 4) - - taskSetManager12.taskFinished() - assert(rootPool.getSchedulableByName("1").runningTasks === 3) - taskSetManager24.abort() - assert(rootPool.getSchedulableByName("2").runningTasks === 2) - } - - test("Nested Pool Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) - val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) - rootPool.addSchedulable(pool0) - rootPool.addSchedulable(pool1) - - val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) - val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) - pool0.addSchedulable(pool00) - pool0.addSchedulable(pool01) - - val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) - val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) - pool1.addSchedulable(pool10) - pool1.addSchedulable(pool11) - - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) - pool00.addSchedulable(taskSetManager000) - pool00.addSchedulable(taskSetManager001) - - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) - pool01.addSchedulable(taskSetManager010) - pool01.addSchedulable(taskSetManager011) - - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) - pool10.addSchedulable(taskSetManager100) - pool10.addSchedulable(taskSetManager101) - - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) - pool11.addSchedulable(taskSetManager110) - pool11.addSchedulable(taskSetManager111) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 6) - checkTaskSetId(rootPool, 2) - } -} diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala deleted file mode 100644 index 5a0b949ef5..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable - -import org.scalatest.FunSuite - -import spark._ -import spark.scheduler._ -import spark.executor.TaskMetrics -import java.nio.ByteBuffer -import spark.util.FakeClock - -/** - * A mock ClusterScheduler implementation that just remembers information about tasks started and - * feedback received from the TaskSetManagers. Note that it's important to initialize this with - * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost - * to work, and these are required for locality in ClusterTaskSetManager. - */ -class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) -{ - val startedTasks = new ArrayBuffer[Long] - val endedTasks = new mutable.HashMap[Long, TaskEndReason] - val finishedManagers = new ArrayBuffer[TaskSetManager] - - val executors = new mutable.HashMap[String, String] ++ liveExecutors - - listener = new TaskSchedulerListener { - def taskStarted(task: Task[_], taskInfo: TaskInfo) { - startedTasks += taskInfo.index - } - - def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: mutable.Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - { - endedTasks(taskInfo.index) = reason - } - - def executorGained(execId: String, host: String) {} - - def executorLost(execId: String) {} - - def taskSetFailed(taskSet: TaskSet, reason: String) {} - } - - def removeExecutor(execId: String): Unit = executors -= execId - - override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager - - override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) - - override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) -} - -class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} - - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - test("TaskSet with no preferences") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // Offer a host with no CPUs - assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) - - // Offer a host with process-local as the constraint; this should work because the TaskSet - // above won't have any locality preferences - val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - assert(sched.startedTasks.contains(0)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) - - // Tell it the task has finished - manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) - assert(sched.endedTasks(0) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("multiple offers with no preferences") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(3) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // First three offers should all find tasks - for (i <- 0 until 3) { - val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - } - assert(sched.startedTasks.toSet === Set(0, 1, 2)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Finish the first two tasks - manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) - manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1)) - assert(sched.endedTasks(0) === Success) - assert(sched.endedTasks(1) === Success) - assert(!sched.finishedManagers.contains(manager)) - - // Finish the last task - manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2)) - assert(sched.endedTasks(2) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("basic delay scheduling") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(4, - Seq(TaskLocation("host1", "exec1")), - Seq(TaskLocation("host2", "exec2")), - Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), - Seq() // Last task has no locality prefs - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1, exec1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1, exec1 again: the last task, which has no prefs, should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) - - // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) - - // Offer host1, exec1 again, at ANY level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at ANY level: task 1 should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - } - - test("delay scheduling with fallback") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, - ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) - val taskSet = createTaskSet(5, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")), - Seq(TaskLocation("host2")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1 again: second task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1 again: third task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // Offer host2: fifth task (also on host2) should get chosen - assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) - - // Now that we've launched a local task, we should no longer launch the task for host3 - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // After another delay, we can go ahead and launch that task non-locally - assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) - } - - test("delay scheduling with failed hosts") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(3, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: third task should be chosen immediately because host3 is not up - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // After this, nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - // Now mark host2 as dead - sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") - - // Task 1 should immediately be launched on host1 because its original host is gone - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Now that all tasks have launched, nothing new should be launched anywhere else - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - } - - /** - * Utility method to create a TaskSet, potentially setting a particular sequence of preferred - * locations for each task (given as varargs) if this sequence is not empty. - */ - def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - if (prefLocs.size != 0 && prefLocs.size != numTasks) { - throw new IllegalArgumentException("Wrong number of task locations") - } - val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) - } - new TaskSet(tasks, 0, 0, 0, null) - } - - def createTaskResult(id: Int): ByteBuffer = { - ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics))) - } -} diff --git a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala deleted file mode 100644 index de9e66be20..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.scheduler.{TaskLocation, Task} - -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { - override def run(attemptId: Long): Int = 0 - - override def preferredLocations: Seq[TaskLocation] = prefLocs -} diff --git a/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala deleted file mode 100644 index d28ee47fa3..0000000000 --- a/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.local - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark._ -import spark.scheduler._ -import spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.{ConcurrentMap, HashMap} -import java.util.concurrent.Semaphore -import java.util.concurrent.CountDownLatch -import java.util.Properties - -class Lock() { - var finished = false - def jobWait() = { - synchronized { - while(!finished) { - this.wait() - } - } - } - - def jobFinished() = { - synchronized { - finished = true - this.notifyAll() - } - } -} - -object TaskThreadInfo { - val threadToLock = HashMap[Int, Lock]() - val threadToRunning = HashMap[Int, Boolean]() - val threadToStarted = HashMap[Int, CountDownLatch]() -} - -/* - * 1. each thread contains one job. - * 2. each job contains one stage. - * 3. each stage only contains one task. - * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure - * it will get cpu core resource, and will wait to finished after user manually - * release "Lock" and then cluster will contain another free cpu cores. - * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, - * thus it will be scheduled later when cluster has free cpu cores. - */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext { - - def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { - - TaskThreadInfo.threadToRunning(threadIndex) = false - val nums = sc.parallelize(threadIndex to threadIndex, 1) - TaskThreadInfo.threadToLock(threadIndex) = new Lock() - TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) - new Thread { - if (poolName != null) { - sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToStarted(number).countDown() - TaskThreadInfo.threadToLock(number).jobWait() - TaskThreadInfo.threadToRunning(number) = false - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - } - }.start() - } - - test("Local FIFO scheduler end-to-end test") { - System.setProperty("spark.cluster.schedulingmode", "FIFO") - sc = new SparkContext("local[4]", "test") - val sem = new Semaphore(0) - - createThread(1,null,sc,sem) - TaskThreadInfo.threadToStarted(1).await() - createThread(2,null,sc,sem) - TaskThreadInfo.threadToStarted(2).await() - createThread(3,null,sc,sem) - TaskThreadInfo.threadToStarted(3).await() - createThread(4,null,sc,sem) - TaskThreadInfo.threadToStarted(4).await() - // thread 5 and 6 (stage pending)must meet following two points - // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager - // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() - // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 - // So I just use "sleep" 1s here for each thread. - // TODO: any better solution? - createThread(5,null,sc,sem) - Thread.sleep(1000) - createThread(6,null,sc,sem) - Thread.sleep(1000) - - assert(TaskThreadInfo.threadToRunning(1) === true) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === false) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(1).jobFinished() - TaskThreadInfo.threadToStarted(5).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(3).jobFinished() - TaskThreadInfo.threadToStarted(6).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === false) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === true) - - TaskThreadInfo.threadToLock(2).jobFinished() - TaskThreadInfo.threadToLock(4).jobFinished() - TaskThreadInfo.threadToLock(5).jobFinished() - TaskThreadInfo.threadToLock(6).jobFinished() - sem.acquire(6) - } - - test("Local fair scheduler end-to-end test") { - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) - System.setProperty("spark.cluster.schedulingmode", "FAIR") - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.fairscheduler.allocation.file", xmlPath) - - createThread(10,"1",sc,sem) - TaskThreadInfo.threadToStarted(10).await() - createThread(20,"2",sc,sem) - TaskThreadInfo.threadToStarted(20).await() - createThread(30,"3",sc,sem) - TaskThreadInfo.threadToStarted(30).await() - - assert(TaskThreadInfo.threadToRunning(10) === true) - assert(TaskThreadInfo.threadToRunning(20) === true) - assert(TaskThreadInfo.threadToRunning(30) === true) - - createThread(11,"1",sc,sem) - TaskThreadInfo.threadToStarted(11).await() - createThread(21,"2",sc,sem) - TaskThreadInfo.threadToStarted(21).await() - createThread(31,"3",sc,sem) - TaskThreadInfo.threadToStarted(31).await() - - assert(TaskThreadInfo.threadToRunning(11) === true) - assert(TaskThreadInfo.threadToRunning(21) === true) - assert(TaskThreadInfo.threadToRunning(31) === true) - - createThread(12,"1",sc,sem) - TaskThreadInfo.threadToStarted(12).await() - createThread(22,"2",sc,sem) - TaskThreadInfo.threadToStarted(22).await() - createThread(32,"3",sc,sem) - - assert(TaskThreadInfo.threadToRunning(12) === true) - assert(TaskThreadInfo.threadToRunning(22) === true) - assert(TaskThreadInfo.threadToRunning(32) === false) - - TaskThreadInfo.threadToLock(10).jobFinished() - TaskThreadInfo.threadToStarted(32).await() - - assert(TaskThreadInfo.threadToRunning(32) === true) - - //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager - // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. - //2. priority of 23 and 33 will be meaningless as using fair scheduler here. - createThread(23,"2",sc,sem) - createThread(33,"3",sc,sem) - Thread.sleep(1000) - - TaskThreadInfo.threadToLock(11).jobFinished() - TaskThreadInfo.threadToStarted(23).await() - - assert(TaskThreadInfo.threadToRunning(23) === true) - assert(TaskThreadInfo.threadToRunning(33) === false) - - TaskThreadInfo.threadToLock(12).jobFinished() - TaskThreadInfo.threadToStarted(33).await() - - assert(TaskThreadInfo.threadToRunning(33) === true) - - TaskThreadInfo.threadToLock(20).jobFinished() - TaskThreadInfo.threadToLock(21).jobFinished() - TaskThreadInfo.threadToLock(22).jobFinished() - TaskThreadInfo.threadToLock(23).jobFinished() - TaskThreadInfo.threadToLock(30).jobFinished() - TaskThreadInfo.threadToLock(31).jobFinished() - TaskThreadInfo.threadToLock(32).jobFinished() - TaskThreadInfo.threadToLock(33).jobFinished() - - sem.acquire(11) - } -} diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala deleted file mode 100644 index b719d65342..0000000000 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ /dev/null @@ -1,665 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import akka.actor._ - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.PrivateMethodTester -import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers._ -import org.scalatest.time.SpanSugar._ - -import spark.JavaSerializer -import spark.KryoSerializer -import spark.SizeEstimator -import spark.util.AkkaUtils -import spark.util.ByteBufferInputStream - - -class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { - var store: BlockManager = null - var store2: BlockManager = null - var actorSystem: ActorSystem = null - var master: BlockManagerMaster = null - var oldArch: String = null - var oldOops: String = null - var oldHeartBeat: String = null - - // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val serializer = new KryoSerializer - - before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) - this.actorSystem = actorSystem - System.setProperty("spark.driver.port", boundPort.toString) - System.setProperty("spark.hostPort", "localhost:" + boundPort) - - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true)))) - - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - // Set some value ... - System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111) - } - - after { - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - - if (store != null) { - store.stop() - store = null - } - if (store2 != null) { - store2.stop() - store2 = null - } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null - master = null - - if (oldArch != null) { - System.setProperty("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) - } else { - System.clearProperty("spark.test.useCompressedOops") - } - } - - test("StorageLevel object caching") { - val level1 = StorageLevel(false, false, false, 3) - val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, 2) // this should return a different object - assert(level2 === level1, "level2 is not same as level1") - assert(level2.eq(level1), "level2 is not the same object as level1") - assert(level3 != level1, "level3 is same as level1") - val bytes1 = spark.Utils.serialize(level1) - val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) - val bytes2 = spark.Utils.serialize(level2) - val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") - assert(level2_ === level2, "Deserialized level2 not same as original level2") - assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") - } - - test("BlockManagerId object caching") { - val id1 = BlockManagerId("e1", "XXX", 1, 0) - val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 - val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object - assert(id2 === id1, "id2 is not same as id1") - assert(id2.eq(id1), "id2 is not the same object as id1") - assert(id3 != id1, "id3 is same as id1") - val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) - val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) - assert(id1_ === id1, "Deserialized id1 is not same as original id1") - assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") - assert(id2_ === id2, "Deserialized id2 is not same as original id2") - assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") - } - - test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) - - // Checking whether blocks are in memory - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - - // Checking whether master knows about the blocks or not - assert(master.getLocations("a1").size > 0, "master was not told about a1") - assert(master.getLocations("a2").size > 0, "master was not told about a2") - assert(master.getLocations("a3").size === 0, "master was told about a3") - - // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.getLocations("a1").size === 0, "master did not remove a1") - assert(master.getLocations("a2").size === 0, "master did not remove a2") - } - - test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) - - val peers = master.getPeers(store.blockManagerId, 1) - assert(peers.size === 1, "master did not return the other manager as a peer") - assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") - - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) - store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) - assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") - assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") - } - - test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) - - // Checking whether blocks are in memory and memory size - val memStatus = master.getMemoryStatus.head._2 - assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") - assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") - assert(store.getSingle("a1-to-remove") != None, "a1 was not in store") - assert(store.getSingle("a2-to-remove") != None, "a2 was not in store") - assert(store.getSingle("a3-to-remove") != None, "a3 was not in store") - - // Checking whether master knows about the blocks or not - assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") - assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2") - assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3") - - // Remove a1 and a2 and a3. Should be no-op for a3. - master.removeBlock("a1-to-remove") - master.removeBlock("a2-to-remove") - master.removeBlock("a3-to-remove") - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a1-to-remove") should be (None) - master.getLocations("a1-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a2-to-remove") should be (None) - master.getLocations("a2-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a3-to-remove") should not be (None) - master.getLocations("a3-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - val memStatus = master.getMemoryStatus.head._2 - memStatus._1 should equal (2000L) - memStatus._2 should equal (2000L) - } - } - - test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - // Putting a1, a2 and a3 in memory. - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) - master.removeRdd(0, blocking = false) - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("nonrddblock") should not be (None) - master.getLocations("nonrddblock") should have size (1) - } - - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) - master.removeRdd(0, blocking = true) - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 - } - - test("reregistration on heart beat") { - val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - master.removeExecutor(store.blockManagerId.executorId) - assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - - store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") - } - - test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - master.removeExecutor(store.blockManagerId.executorId) - assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.waitForAsyncReregister() - - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") - assert(master.getLocations("a2").size > 0, "master was not told about a2") - } - - test("reregistration doesn't dead lock") { - val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = List(new Array[Byte](400)) - - // try many times to trigger any deadlocks - for (i <- 1 to 100) { - master.removeExecutor(store.blockManagerId.executorId) - val t1 = new Thread { - override def run() { - store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - } - } - val t2 = new Thread { - override def run() { - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - } - } - val t3 = new Thread { - override def run() { - store invokePrivate heartBeat() - } - } - - t1.start() - t2.start() - t3.start() - t1.join() - t2.join() - t3.join() - - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) - store.waitForAsyncReregister() - } - } - - test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") - } - - test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") - } - - test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) - // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 - // from the same RDD - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") - // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - } - - test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - // At this point rdd_1_1 should've replaced rdd_0_1 - assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped - // when we try to add rdd_0_4. - assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") - } - - test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - store.putSingle("a2", a2, StorageLevel.DISK_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2") != None, "a2 was in store") - assert(store.getSingle("a3") != None, "a3 was in store") - assert(store.getSingle("a1") != None, "a1 was in store") - } - - test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getLocalBytes("a2") != None, "a2 was not in store") - assert(store.getLocalBytes("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getLocalBytes("a2") != None, "a2 was not in store") - assert(store.getLocalBytes("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - val a4 = new Array[Byte](400) - // First store a1 and a2, both in memory, and a3, on disk only - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - // Now let's add in a4, which uses both disk and memory; a1 should drop out - store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a4") != None, "a4 was not in store") - } - - test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") != None, "list3 was not in store") - assert(store.get("list3").get.size == 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list1") != None, "list1 was not in store") - assert(store.get("list1").get.size == 2) - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") === None, "list1 was in store") - } - - test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - val list4 = List(new Array[Byte](200), new Array[Byte](200)) - // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) - // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list4") != None, "list4 was not in store") - assert(store.get("list4").get.size === 2) - } - - test("negative byte values in ByteBufferInputStream") { - val buffer = ByteBuffer.wrap(Array[Int](254, 255, 0, 1, 2).map(_.toByte).toArray) - val stream = new ByteBufferInputStream(buffer) - val temp = new Array[Byte](10) - assert(stream.read() === 254, "unexpected byte read") - assert(stream.read() === 255, "unexpected byte read") - assert(stream.read() === 0, "unexpected byte read") - assert(stream.read(temp, 0, temp.length) === 2, "unexpected number of bytes read") - assert(stream.read() === -1, "end of stream not signalled") - assert(stream.read(temp, 0, temp.length) === -1, "end of stream not signalled") - } - - test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500) - store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") === None, "a1 was in store") - store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) - assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") - assert(store.getSingle("a2") != None, "a2 was not in store") - } - - test("block compression") { - try { - System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") - store.stop() - store = null - - System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") - store.stop() - store = null - - System.setProperty("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") - store.stop() - store = null - - // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000) - store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) - assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") - store.stop() - store = null - } finally { - System.clearProperty("spark.shuffle.compress") - System.clearProperty("spark.broadcast.compress") - System.clearProperty("spark.rdd.compress") - } - } - - test("block store put failure") { - // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer, 1200) - - // The put should fail since a1 is not serializable. - class UnserializableClass - val a1 = new UnserializableClass - intercept[java.io.NotSerializableException] { - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - } - - // Make sure get a1 doesn't hang and returns None. - failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") - } - } -} diff --git a/core/src/test/scala/spark/ui/UISuite.scala b/core/src/test/scala/spark/ui/UISuite.scala deleted file mode 100644 index 735a794396..0000000000 --- a/core/src/test/scala/spark/ui/UISuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import scala.util.{Failure, Success, Try} -import java.net.ServerSocket -import org.scalatest.FunSuite -import org.eclipse.jetty.server.Server - -class UISuite extends FunSuite { - test("jetty port increases under contention") { - val startPort = 3030 - val server = new Server(startPort) - server.start() - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - - // Allow some wiggle room in case ports on the machine are under contention - assert(boundPort1 > startPort && boundPort1 < startPort + 10) - assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) - } - - test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq()) - assert(jettyServer.getState === "STARTED") - assert(boundPort != 0) - Try {new ServerSocket(boundPort)} match { - case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) - case Failure (e) => - } - } -} diff --git a/core/src/test/scala/spark/util/DistributionSuite.scala b/core/src/test/scala/spark/util/DistributionSuite.scala deleted file mode 100644 index 6578b55e82..0000000000 --- a/core/src/test/scala/spark/util/DistributionSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -/** - * - */ - -class DistributionSuite extends FunSuite with ShouldMatchers { - test("summary") { - val d = new Distribution((1 to 100).toArray.map{_.toDouble}) - val stats = d.statCounter - stats.count should be (100) - stats.mean should be (50.5) - stats.sum should be (50 * 101) - - val quantiles = d.getQuantiles() - quantiles(0) should be (1) - quantiles(1) should be (26) - quantiles(2) should be (51) - quantiles(3) should be (76) - quantiles(4) should be (100) - } -} diff --git a/core/src/test/scala/spark/util/FakeClock.scala b/core/src/test/scala/spark/util/FakeClock.scala deleted file mode 100644 index 236706317e..0000000000 --- a/core/src/test/scala/spark/util/FakeClock.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -class FakeClock extends Clock { - private var time = 0L - - def advance(millis: Long): Unit = time += millis - - def getTime(): Long = time -} diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala deleted file mode 100644 index fdbd43d941..0000000000 --- a/core/src/test/scala/spark/util/NextIteratorSuite.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import scala.collection.mutable.Buffer -import java.util.NoSuchElementException - -class NextIteratorSuite extends FunSuite with ShouldMatchers { - test("one iteration") { - val i = new StubIterator(Buffer(1)) - i.hasNext should be === true - i.next should be === 1 - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("two iterations") { - val i = new StubIterator(Buffer(1, 2)) - i.hasNext should be === true - i.next should be === 1 - i.hasNext should be === true - i.next should be === 2 - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("empty iteration") { - val i = new StubIterator(Buffer()) - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("close is called once for empty iterations") { - val i = new StubIterator(Buffer()) - i.hasNext should be === false - i.hasNext should be === false - i.closeCalled should be === 1 - } - - test("close is called once for non-empty iterations") { - val i = new StubIterator(Buffer(1, 2)) - i.next should be === 1 - i.next should be === 2 - // close isn't called until we check for the next element - i.closeCalled should be === 0 - i.hasNext should be === false - i.closeCalled should be === 1 - i.hasNext should be === false - i.closeCalled should be === 1 - } - - class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { - var closeCalled = 0 - - override def getNext() = { - if (ints.size == 0) { - finished = true - 0 - } else { - ints.remove(0) - } - } - - override def close() { - closeCalled += 1 - } - } -} diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala deleted file mode 100644 index 4c0044202f..0000000000 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import org.scalatest.FunSuite -import java.io.ByteArrayOutputStream -import java.util.concurrent.TimeUnit._ - -class RateLimitedOutputStreamSuite extends FunSuite { - - private def benchmark[U](f: => U): Long = { - val start = System.nanoTime - f - System.nanoTime - start - } - - test("write") { - val underlying = new ByteArrayOutputStream - val data = "X" * 41000 - val stream = new RateLimitedOutputStream(underlying, 10000) - val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) - assert(underlying.toString("UTF-8") == data) - } -} diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 84749fda4e..349eb92a47 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -100,7 +100,7 @@
  • Tuning Guide
  • Hardware Provisioning
  • Building Spark with Maven
  • -
  • Contributing to Spark
  • +
  • Contributing to Spark
  • diff --git a/examples/pom.xml b/examples/pom.xml index 687fbcca8f..13b5531511 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-examples jar Spark Project Examples @@ -33,25 +33,25 @@ - org.spark-project + org.apache.spark spark-core ${project.version} provided - org.spark-project + org.apache.spark spark-streaming ${project.version} provided - org.spark-project + org.apache.spark spark-mllib ${project.version} provided - org.spark-project + org.apache.spark spark-bagel ${project.version} provided @@ -132,7 +132,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} provided diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java new file mode 100644 index 0000000000..be0d38589c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; +import java.util.Random; + +/** + * Logistic regression based classification. + */ +public class JavaHdfsLR { + + static int D = 10; // Number of dimensions + static Random rand = new Random(42); + + static class DataPoint implements Serializable { + public DataPoint(double[] x, double y) { + this.x = x; + this.y = y; + } + + double[] x; + double y; + } + + static class ParsePoint extends Function { + public DataPoint call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + double y = Double.parseDouble(tok.nextToken()); + double[] x = new double[D]; + int i = 0; + while (i < D) { + x[i] = Double.parseDouble(tok.nextToken()); + i += 1; + } + return new DataPoint(x, y); + } + } + + static class VectorSum extends Function2 { + public double[] call(double[] a, double[] b) { + double[] result = new double[D]; + for (int j = 0; j < D; j++) { + result[j] = a[j] + b[j]; + } + return result; + } + } + + static class ComputeGradient extends Function { + double[] weights; + + public ComputeGradient(double[] weights) { + this.weights = weights; + } + + public double[] call(DataPoint p) { + double[] gradient = new double[D]; + for (int i = 0; i < D; i++) { + double dot = dot(weights, p.x); + gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i]; + } + return gradient; + } + } + + public static double dot(double[] a, double[] b) { + double x = 0; + for (int i = 0; i < D; i++) { + x += a[i] * b[i]; + } + return x; + } + + public static void printWeights(double[] a) { + System.out.println(Arrays.toString(a)); + } + + public static void main(String[] args) { + + if (args.length < 3) { + System.err.println("Usage: JavaHdfsLR "); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + JavaRDD points = lines.map(new ParsePoint()).cache(); + int ITERATIONS = Integer.parseInt(args[2]); + + // Initialize w to a random value + double[] w = new double[D]; + for (int i = 0; i < D; i++) { + w[i] = 2 * rand.nextDouble() - 1; + } + + System.out.print("Initial w: "); + printWeights(w); + + for (int i = 1; i <= ITERATIONS; i++) { + System.out.println("On iteration " + i); + + double[] gradient = points.map( + new ComputeGradient(w) + ).reduce(new VectorSum()); + + for (int j = 0; j < D; j++) { + w[j] -= gradient[j]; + } + + } + + System.out.print("Final w: "); + printWeights(w); + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java new file mode 100644 index 0000000000..5a6afe7eae --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.util.Vector; + +import java.util.List; +import java.util.Map; + +/** + * K-means clustering using Java API. + */ +public class JavaKMeans { + + /** Parses numbers split by whitespace to a vector */ + static Vector parseVector(String line) { + String[] splits = line.split(" "); + double[] data = new double[splits.length]; + int i = 0; + for (String s : splits) + data[i] = Double.parseDouble(splits[i++]); + return new Vector(data); + } + + /** Computes the vector to which the input vector is closest using squared distance */ + static int closestPoint(Vector p, List centers) { + int bestIndex = 0; + double closest = Double.POSITIVE_INFINITY; + for (int i = 0; i < centers.size(); i++) { + double tempDist = p.squaredDist(centers.get(i)); + if (tempDist < closest) { + closest = tempDist; + bestIndex = i; + } + } + return bestIndex; + } + + /** Computes the mean across all vectors in the input set of vectors */ + static Vector average(List ps) { + int numVectors = ps.size(); + Vector out = new Vector(ps.get(0).elements()); + // start from i = 1 since we already copied index 0 above + for (int i = 1; i < numVectors; i++) { + out.addInPlace(ps.get(i)); + } + return out.divide(numVectors); + } + + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.err.println("Usage: JavaKMeans "); + System.exit(1); + } + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + String path = args[1]; + int K = Integer.parseInt(args[2]); + double convergeDist = Double.parseDouble(args[3]); + + JavaRDD data = sc.textFile(path).map( + new Function() { + @Override + public Vector call(String line) throws Exception { + return parseVector(line); + } + } + ).cache(); + + final List centroids = data.takeSample(false, K, 42); + + double tempDist; + do { + // allocate each vector to closest centroid + JavaPairRDD closest = data.map( + new PairFunction() { + @Override + public Tuple2 call(Vector vector) throws Exception { + return new Tuple2( + closestPoint(vector, centroids), vector); + } + } + ); + + // group by cluster id and average the vectors within each cluster to compute centroids + JavaPairRDD> pointsGroup = closest.groupByKey(); + Map newCentroids = pointsGroup.mapValues( + new Function, Vector>() { + public Vector call(List ps) throws Exception { + return average(ps); + } + }).collectAsMap(); + tempDist = 0.0; + for (int i = 0; i < K; i++) { + tempDist += centroids.get(i).squaredDist(newCentroids.get(i)); + } + for (Map.Entry t: newCentroids.entrySet()) { + centroids.set(t.getKey(), t.getValue()); + } + System.out.println("Finished iteration (delta = " + tempDist + ")"); + } while (tempDist > convergeDist); + + System.out.println("Final centers:"); + for (Vector c : centroids) + System.out.println(c); + + System.exit(0); + + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java new file mode 100644 index 0000000000..152f029213 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import scala.Tuple3; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; + +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Executes a roll up-style query against Apache logs. + */ +public class JavaLogQuery { + + public static List exampleApacheLogs = Lists.newArrayList( + "10.10.10.10 - \"FRED\" [18/Jan/2013:17:56:07 +1100] \"GET http://images.com/2013/Generic.jpg " + + "HTTP/1.1\" 304 315 \"http://referall.com/\" \"Mozilla/4.0 (compatible; MSIE 7.0; " + + "Windows NT 5.1; GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; " + + ".NET CLR 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + + "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.350 \"-\" - \"\" 265 923 934 \"\" " + + "62.24.11.25 images.com 1358492167 - Whatup", + "10.10.10.10 - \"FRED\" [18/Jan/2013:18:02:37 +1100] \"GET http://images.com/2013/Generic.jpg " + + "HTTP/1.1\" 304 306 \"http:/referall.com\" \"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; " + + "GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR " + + "3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + + "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.352 \"-\" - \"\" 256 977 988 \"\" " + + "0 73.23.2.15 images.com 1358492557 - Whatup"); + + public static Pattern apacheLogRegex = Pattern.compile( + "^([\\d.]+) (\\S+) (\\S+) \\[([\\w\\d:/]+\\s[+\\-]\\d{4})\\] \"(.+?)\" (\\d{3}) ([\\d\\-]+) \"([^\"]+)\" \"([^\"]+)\".*"); + + /** Tracks the total query count and number of aggregate bytes for a particular group. */ + public static class Stats implements Serializable { + + private int count; + private int numBytes; + + public Stats(int count, int numBytes) { + this.count = count; + this.numBytes = numBytes; + } + public Stats merge(Stats other) { + return new Stats(count + other.count, numBytes + other.numBytes); + } + + public String toString() { + return String.format("bytes=%s\tn=%s", numBytes, count); + } + } + + public static Tuple3 extractKey(String line) { + Matcher m = apacheLogRegex.matcher(line); + List key = Collections.emptyList(); + if (m.find()) { + String ip = m.group(1); + String user = m.group(3); + String query = m.group(5); + if (!user.equalsIgnoreCase("-")) { + return new Tuple3(ip, user, query); + } + } + return new Tuple3(null, null, null); + } + + public static Stats extractStats(String line) { + Matcher m = apacheLogRegex.matcher(line); + if (m.find()) { + int bytes = Integer.parseInt(m.group(7)); + return new Stats(1, bytes); + } + else + return new Stats(1, 0); + } + + public static void main(String[] args) throws Exception { + if (args.length == 0) { + System.err.println("Usage: JavaLogQuery [logFile]"); + System.exit(1); + } + + JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); + + JavaPairRDD, Stats> extracted = dataSet.map(new PairFunction, Stats>() { + @Override + public Tuple2, Stats> call(String s) throws Exception { + return new Tuple2, Stats>(extractKey(s), extractStats(s)); + } + }); + + JavaPairRDD, Stats> counts = extracted.reduceByKey(new Function2() { + @Override + public Stats call(Stats stats, Stats stats2) throws Exception { + return stats.merge(stats2); + } + }); + + List, Stats>> output = counts.collect(); + for (Tuple2 t : output) { + System.out.println(t._1 + "\t" + t._2); + } + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java new file mode 100644 index 0000000000..c5603a639b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.List; +import java.util.ArrayList; + +/** + * Computes the PageRank of URLs from an input file. Input file should + * be in format of: + * URL neighbor URL + * URL neighbor URL + * URL neighbor URL + * ... + * where URL and their neighbors are separated by space(s). + */ +public class JavaPageRank { + private static class Sum extends Function2 { + @Override + public Double call(Double a, Double b) { + return a + b; + } + } + + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.err.println("Usage: JavaPageRank "); + System.exit(1); + } + + JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Loads in input file. It should be in format of: + // URL neighbor URL + // URL neighbor URL + // URL neighbor URL + // ... + JavaRDD lines = ctx.textFile(args[1], 1); + + // Loads all URLs from input file and initialize their neighbors. + JavaPairRDD> links = lines.map(new PairFunction() { + @Override + public Tuple2 call(String s) { + String[] parts = s.split("\\s+"); + return new Tuple2(parts[0], parts[1]); + } + }).distinct().groupByKey().cache(); + + // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. + JavaPairRDD ranks = links.mapValues(new Function, Double>() { + @Override + public Double call(List rs) throws Exception { + return 1.0; + } + }); + + // Calculates and updates URL ranks continuously using PageRank algorithm. + for (int current = 0; current < Integer.parseInt(args[2]); current++) { + // Calculates URL contributions to the rank of other URLs. + JavaPairRDD contribs = links.join(ranks).values() + .flatMap(new PairFlatMapFunction, Double>, String, Double>() { + @Override + public Iterable> call(Tuple2, Double> s) { + List> results = new ArrayList>(); + for (String n : s._1) { + results.add(new Tuple2(n, s._2 / s._1.size())); + } + return results; + } + }); + + // Re-calculates URL ranks based on neighbor contributions. + ranks = contribs.reduceByKey(new Sum()).mapValues(new Function() { + @Override + public Double call(Double sum) throws Exception { + return 0.15 + sum * 0.85; + } + }); + } + + // Collects all URL ranks and dump them to console. + List> output = ranks.collect(); + for (Tuple2 tuple : output) { + System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); + } + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java new file mode 100644 index 0000000000..4a2380caf5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; + +import java.util.ArrayList; +import java.util.List; + +/** Computes an approximation to pi */ +public class JavaSparkPi { + + + public static void main(String[] args) throws Exception { + if (args.length == 0) { + System.err.println("Usage: JavaLogQuery [slices]"); + System.exit(1); + } + + JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; + int n = 100000 * slices; + List l = new ArrayList(n); + for (int i = 0; i < n; i++) + l.add(i); + + JavaRDD dataSet = jsc.parallelize(l, slices); + + int count = dataSet.map(new Function() { + @Override + public Integer call(Integer integer) throws Exception { + double x = Math.random() * 2 - 1; + double y = Math.random() * 2 - 1; + return (x * x + y * y < 1) ? 1 : 0; + } + }).reduce(new Function2() { + @Override + public Integer call(Integer integer, Integer integer2) throws Exception { + return integer + integer2; + } + }); + + System.out.println("Pi is roughly " + 4.0 * count / n); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java new file mode 100644 index 0000000000..17f21f6b77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; + +/** + * Transitive closure on a graph, implemented in Java. + */ +public class JavaTC { + + static int numEdges = 200; + static int numVertices = 100; + static Random rand = new Random(42); + + static List> generateGraph() { + Set> edges = new HashSet>(numEdges); + while (edges.size() < numEdges) { + int from = rand.nextInt(numVertices); + int to = rand.nextInt(numVertices); + Tuple2 e = new Tuple2(from, to); + if (from != to) edges.add(e); + } + return new ArrayList>(edges); + } + + static class ProjectFn extends PairFunction>, + Integer, Integer> { + static ProjectFn INSTANCE = new ProjectFn(); + + public Tuple2 call(Tuple2> triple) { + return new Tuple2(triple._2()._2(), triple._2()._1()); + } + } + + public static void main(String[] args) { + if (args.length == 0) { + System.err.println("Usage: JavaTC []"); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; + JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); + + // Linear transitive closure: each round grows paths by one edge, + // by joining the graph's edges with the already-discovered paths. + // e.g. join the path (y, z) from the TC with the edge (x, y) from + // the graph to obtain the path (x, z). + + // Because join() joins on keys, the edges are stored in reversed order. + JavaPairRDD edges = tc.map( + new PairFunction, Integer, Integer>() { + public Tuple2 call(Tuple2 e) { + return new Tuple2(e._2(), e._1()); + } + }); + + long oldCount = 0; + long nextCount = tc.count(); + do { + oldCount = nextCount; + // Perform the join, obtaining an RDD of (y, (z, x)) pairs, + // then project the result to obtain the new (x, z) paths. + tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache(); + nextCount = tc.count(); + } while (nextCount != oldCount); + + System.out.println("TC has " + tc.count() + " edges."); + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java new file mode 100644 index 0000000000..07d32ad659 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.Arrays; +import java.util.List; + +public class JavaWordCount { + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaWordCount "); + System.exit(1); + } + + JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = ctx.textFile(args[1], 1); + + JavaRDD words = lines.flatMap(new FlatMapFunction() { + public Iterable call(String s) { + return Arrays.asList(s.split(" ")); + } + }); + + JavaPairRDD ones = words.map(new PairFunction() { + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + JavaPairRDD counts = ones.reduceByKey(new Function2() { + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + List> output = counts.collect(); + for (Tuple2 tuple : output) { + System.out.println(tuple._1 + ": " + tuple._2); + } + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java new file mode 100644 index 0000000000..628cb892b6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; + +import scala.Tuple2; + +/** + * Example using MLLib ALS from Java. + */ +public class JavaALS { + + static class ParseRating extends Function { + public Rating call(String line) { + StringTokenizer tok = new StringTokenizer(line, ","); + int x = Integer.parseInt(tok.nextToken()); + int y = Integer.parseInt(tok.nextToken()); + double rating = Double.parseDouble(tok.nextToken()); + return new Rating(x, y, rating); + } + } + + static class FeaturesToString extends Function, String> { + public String call(Tuple2 element) { + return element._1().toString() + "," + Arrays.toString(element._2()); + } + } + + public static void main(String[] args) { + + if (args.length != 5 && args.length != 6) { + System.err.println( + "Usage: JavaALS []"); + System.exit(1); + } + + int rank = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + String outputDir = args[4]; + int blocks = -1; + if (args.length == 6) { + blocks = Integer.parseInt(args[5]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + + JavaRDD ratings = lines.map(new ParseRating()); + + MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); + + model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/userFeatures"); + model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/productFeatures"); + System.out.println("Final user/product features written to " + outputDir); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java new file mode 100644 index 0000000000..cd59a139b9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.clustering.KMeans; +import org.apache.spark.mllib.clustering.KMeansModel; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Example using MLLib KMeans from Java. + */ +public class JavaKMeans { + + static class ParsePoint extends Function { + public double[] call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + int numTokens = tok.countTokens(); + double[] point = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + point[i] = Double.parseDouble(tok.nextToken()); + } + return point; + } + } + + public static void main(String[] args) { + + if (args.length < 4) { + System.err.println( + "Usage: JavaKMeans []"); + System.exit(1); + } + + String inputFile = args[1]; + int k = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + int runs = 1; + + if (args.length >= 5) { + runs = Integer.parseInt(args[4]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + + JavaRDD points = lines.map(new ParsePoint()); + + KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); + + System.out.println("Cluster centers:"); + for (double[] center : model.clusterCenters()) { + System.out.println(" " + Arrays.toString(center)); + } + double cost = model.computeCost(points.rdd()); + System.out.println("Cost: " + cost); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java new file mode 100644 index 0000000000..258061c8e6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Logistic regression based classification using ML Lib. + */ +public class JavaLR { + + static class ParsePoint extends Function { + public LabeledPoint call(String line) { + String[] parts = line.split(","); + double y = Double.parseDouble(parts[0]); + StringTokenizer tok = new StringTokenizer(parts[1], " "); + int numTokens = tok.countTokens(); + double[] x = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + x[i] = Double.parseDouble(tok.nextToken()); + } + return new LabeledPoint(y, x); + } + } + + public static void printWeights(double[] a) { + System.out.println(Arrays.toString(a)); + } + + public static void main(String[] args) { + if (args.length != 4) { + System.err.println("Usage: JavaLR "); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + JavaRDD points = lines.map(new ParsePoint()).cache(); + double stepSize = Double.parseDouble(args[2]); + int iterations = Integer.parseInt(args[3]); + + // Another way to configure LogisticRegression + // + // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); + // lr.optimizer().setNumIterations(iterations) + // .setStepSize(stepSize) + // .setMiniBatchFraction(1.0); + // lr.setIntercept(true); + // LogisticRegressionModel model = lr.train(points.rdd()); + + LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), + iterations, stepSize); + + System.out.print("Final w: "); + printWeights(model.weights()); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java new file mode 100644 index 0000000000..261813bf2f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.dstream.SparkFlumeEvent; + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: JavaFlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ +public class JavaFlumeEventCount { + public static void main(String[] args) { + if (args.length != 3) { + System.err.println("Usage: JavaFlumeEventCount "); + System.exit(1); + } + + String master = args[0]; + String host = args[1]; + int port = Integer.parseInt(args[2]); + + Duration batchInterval = new Duration(2000); + + JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + JavaDStream flumeStream = sc.flumeStream("localhost", port); + + flumeStream.count(); + + flumeStream.count().map(new Function() { + @Override + public String call(Long in) { + return "Received " + in + " flume events."; + } + }).print(); + + sc.start(); + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java new file mode 100644 index 0000000000..def87c199b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` + */ +public class JavaNetworkWordCount { + public static void main(String[] args) { + if (args.length < 3) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1"); + System.exit(1); + } + + // Create the context with a 1 second batch size + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + new Duration(1000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + JavaDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split(" ")); + } + }); + JavaPairDStream wordCounts = words.map( + new PairFunction() { + @Override + public Tuple2 call(String s) throws Exception { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + wordCounts.print(); + ssc.start(); + + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java new file mode 100644 index 0000000000..c8c7389dd1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + +public class JavaQueueStream { + public static void main(String[] args) throws InterruptedException { + if (args.length < 1) { + System.err.println("Usage: JavaQueueStream "); + System.exit(1); + } + + // Create the context + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + Queue> rddQueue = new LinkedList>(); + + // Create and push some RDDs into the queue + List list = Lists.newArrayList(); + for (int i = 0; i < 1000; i++) { + list.add(i); + } + + for (int i = 0; i < 30; i++) { + rddQueue.add(ssc.sc().parallelize(list)); + } + + + // Create the QueueInputDStream and use it do some processing + JavaDStream inputStream = ssc.queueStream(rddQueue); + JavaPairDStream mappedStream = inputStream.map( + new PairFunction() { + @Override + public Tuple2 call(Integer i) throws Exception { + return new Tuple2(i % 10, 1); + } + }); + JavaPairDStream reducedStream = mappedStream.reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + reducedStream.print(); + ssc.start(); + } +} diff --git a/examples/src/main/java/spark/examples/JavaHdfsLR.java b/examples/src/main/java/spark/examples/JavaHdfsLR.java deleted file mode 100644 index 9485e0cfa9..0000000000 --- a/examples/src/main/java/spark/examples/JavaHdfsLR.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.StringTokenizer; -import java.util.Random; - -/** - * Logistic regression based classification. - */ -public class JavaHdfsLR { - - static int D = 10; // Number of dimensions - static Random rand = new Random(42); - - static class DataPoint implements Serializable { - public DataPoint(double[] x, double y) { - this.x = x; - this.y = y; - } - - double[] x; - double y; - } - - static class ParsePoint extends Function { - public DataPoint call(String line) { - StringTokenizer tok = new StringTokenizer(line, " "); - double y = Double.parseDouble(tok.nextToken()); - double[] x = new double[D]; - int i = 0; - while (i < D) { - x[i] = Double.parseDouble(tok.nextToken()); - i += 1; - } - return new DataPoint(x, y); - } - } - - static class VectorSum extends Function2 { - public double[] call(double[] a, double[] b) { - double[] result = new double[D]; - for (int j = 0; j < D; j++) { - result[j] = a[j] + b[j]; - } - return result; - } - } - - static class ComputeGradient extends Function { - double[] weights; - - public ComputeGradient(double[] weights) { - this.weights = weights; - } - - public double[] call(DataPoint p) { - double[] gradient = new double[D]; - for (int i = 0; i < D; i++) { - double dot = dot(weights, p.x); - gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i]; - } - return gradient; - } - } - - public static double dot(double[] a, double[] b) { - double x = 0; - for (int i = 0; i < D; i++) { - x += a[i] * b[i]; - } - return x; - } - - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - - public static void main(String[] args) { - - if (args.length < 3) { - System.err.println("Usage: JavaHdfsLR "); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - int ITERATIONS = Integer.parseInt(args[2]); - - // Initialize w to a random value - double[] w = new double[D]; - for (int i = 0; i < D; i++) { - w[i] = 2 * rand.nextDouble() - 1; - } - - System.out.print("Initial w: "); - printWeights(w); - - for (int i = 1; i <= ITERATIONS; i++) { - System.out.println("On iteration " + i); - - double[] gradient = points.map( - new ComputeGradient(w) - ).reduce(new VectorSum()); - - for (int j = 0; j < D; j++) { - w[j] -= gradient[j]; - } - - } - - System.out.print("Final w: "); - printWeights(w); - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaKMeans.java b/examples/src/main/java/spark/examples/JavaKMeans.java deleted file mode 100644 index 2d34776177..0000000000 --- a/examples/src/main/java/spark/examples/JavaKMeans.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.PairFunction; -import spark.util.Vector; - -import java.util.List; -import java.util.Map; - -/** - * K-means clustering using Java API. - */ -public class JavaKMeans { - - /** Parses numbers split by whitespace to a vector */ - static Vector parseVector(String line) { - String[] splits = line.split(" "); - double[] data = new double[splits.length]; - int i = 0; - for (String s : splits) - data[i] = Double.parseDouble(splits[i++]); - return new Vector(data); - } - - /** Computes the vector to which the input vector is closest using squared distance */ - static int closestPoint(Vector p, List centers) { - int bestIndex = 0; - double closest = Double.POSITIVE_INFINITY; - for (int i = 0; i < centers.size(); i++) { - double tempDist = p.squaredDist(centers.get(i)); - if (tempDist < closest) { - closest = tempDist; - bestIndex = i; - } - } - return bestIndex; - } - - /** Computes the mean across all vectors in the input set of vectors */ - static Vector average(List ps) { - int numVectors = ps.size(); - Vector out = new Vector(ps.get(0).elements()); - // start from i = 1 since we already copied index 0 above - for (int i = 1; i < numVectors; i++) { - out.addInPlace(ps.get(i)); - } - return out.divide(numVectors); - } - - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.err.println("Usage: JavaKMeans "); - System.exit(1); - } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - String path = args[1]; - int K = Integer.parseInt(args[2]); - double convergeDist = Double.parseDouble(args[3]); - - JavaRDD data = sc.textFile(path).map( - new Function() { - @Override - public Vector call(String line) throws Exception { - return parseVector(line); - } - } - ).cache(); - - final List centroids = data.takeSample(false, K, 42); - - double tempDist; - do { - // allocate each vector to closest centroid - JavaPairRDD closest = data.map( - new PairFunction() { - @Override - public Tuple2 call(Vector vector) throws Exception { - return new Tuple2( - closestPoint(vector, centroids), vector); - } - } - ); - - // group by cluster id and average the vectors within each cluster to compute centroids - JavaPairRDD> pointsGroup = closest.groupByKey(); - Map newCentroids = pointsGroup.mapValues( - new Function, Vector>() { - public Vector call(List ps) throws Exception { - return average(ps); - } - }).collectAsMap(); - tempDist = 0.0; - for (int i = 0; i < K; i++) { - tempDist += centroids.get(i).squaredDist(newCentroids.get(i)); - } - for (Map.Entry t: newCentroids.entrySet()) { - centroids.set(t.getKey(), t.getValue()); - } - System.out.println("Finished iteration (delta = " + tempDist + ")"); - } while (tempDist > convergeDist); - - System.out.println("Final centers:"); - for (Vector c : centroids) - System.out.println(c); - - System.exit(0); - - } -} diff --git a/examples/src/main/java/spark/examples/JavaLogQuery.java b/examples/src/main/java/spark/examples/JavaLogQuery.java deleted file mode 100644 index d22684d980..0000000000 --- a/examples/src/main/java/spark/examples/JavaLogQuery.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import scala.Tuple3; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; - -import java.io.Serializable; -import java.util.Collections; -import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** - * Executes a roll up-style query against Apache logs. - */ -public class JavaLogQuery { - - public static List exampleApacheLogs = Lists.newArrayList( - "10.10.10.10 - \"FRED\" [18/Jan/2013:17:56:07 +1100] \"GET http://images.com/2013/Generic.jpg " + - "HTTP/1.1\" 304 315 \"http://referall.com/\" \"Mozilla/4.0 (compatible; MSIE 7.0; " + - "Windows NT 5.1; GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; " + - ".NET CLR 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + - "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.350 \"-\" - \"\" 265 923 934 \"\" " + - "62.24.11.25 images.com 1358492167 - Whatup", - "10.10.10.10 - \"FRED\" [18/Jan/2013:18:02:37 +1100] \"GET http://images.com/2013/Generic.jpg " + - "HTTP/1.1\" 304 306 \"http:/referall.com\" \"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; " + - "GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR " + - "3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + - "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.352 \"-\" - \"\" 256 977 988 \"\" " + - "0 73.23.2.15 images.com 1358492557 - Whatup"); - - public static Pattern apacheLogRegex = Pattern.compile( - "^([\\d.]+) (\\S+) (\\S+) \\[([\\w\\d:/]+\\s[+\\-]\\d{4})\\] \"(.+?)\" (\\d{3}) ([\\d\\-]+) \"([^\"]+)\" \"([^\"]+)\".*"); - - /** Tracks the total query count and number of aggregate bytes for a particular group. */ - public static class Stats implements Serializable { - - private int count; - private int numBytes; - - public Stats(int count, int numBytes) { - this.count = count; - this.numBytes = numBytes; - } - public Stats merge(Stats other) { - return new Stats(count + other.count, numBytes + other.numBytes); - } - - public String toString() { - return String.format("bytes=%s\tn=%s", numBytes, count); - } - } - - public static Tuple3 extractKey(String line) { - Matcher m = apacheLogRegex.matcher(line); - List key = Collections.emptyList(); - if (m.find()) { - String ip = m.group(1); - String user = m.group(3); - String query = m.group(5); - if (!user.equalsIgnoreCase("-")) { - return new Tuple3(ip, user, query); - } - } - return new Tuple3(null, null, null); - } - - public static Stats extractStats(String line) { - Matcher m = apacheLogRegex.matcher(line); - if (m.find()) { - int bytes = Integer.parseInt(m.group(7)); - return new Stats(1, bytes); - } - else - return new Stats(1, 0); - } - - public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [logFile]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); - - JavaPairRDD, Stats> extracted = dataSet.map(new PairFunction, Stats>() { - @Override - public Tuple2, Stats> call(String s) throws Exception { - return new Tuple2, Stats>(extractKey(s), extractStats(s)); - } - }); - - JavaPairRDD, Stats> counts = extracted.reduceByKey(new Function2() { - @Override - public Stats call(Stats stats, Stats stats2) throws Exception { - return stats.merge(stats2); - } - }); - - List, Stats>> output = counts.collect(); - for (Tuple2 t : output) { - System.out.println(t._1 + "\t" + t._2); - } - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaPageRank.java b/examples/src/main/java/spark/examples/JavaPageRank.java deleted file mode 100644 index 75df1af2e3..0000000000 --- a/examples/src/main/java/spark/examples/JavaPageRank.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFlatMapFunction; -import spark.api.java.function.PairFunction; - -import java.util.List; -import java.util.ArrayList; - -/** - * Computes the PageRank of URLs from an input file. Input file should - * be in format of: - * URL neighbor URL - * URL neighbor URL - * URL neighbor URL - * ... - * where URL and their neighbors are separated by space(s). - */ -public class JavaPageRank { - private static class Sum extends Function2 { - @Override - public Double call(Double a, Double b) { - return a + b; - } - } - - public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.err.println("Usage: JavaPageRank "); - System.exit(1); - } - - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Loads in input file. It should be in format of: - // URL neighbor URL - // URL neighbor URL - // URL neighbor URL - // ... - JavaRDD lines = ctx.textFile(args[1], 1); - - // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.map(new PairFunction() { - @Override - public Tuple2 call(String s) { - String[] parts = s.split("\\s+"); - return new Tuple2(parts[0], parts[1]); - } - }).distinct().groupByKey().cache(); - - // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - JavaPairRDD ranks = links.mapValues(new Function, Double>() { - @Override - public Double call(List rs) throws Exception { - return 1.0; - } - }); - - // Calculates and updates URL ranks continuously using PageRank algorithm. - for (int current = 0; current < Integer.parseInt(args[2]); current++) { - // Calculates URL contributions to the rank of other URLs. - JavaPairRDD contribs = links.join(ranks).values() - .flatMap(new PairFlatMapFunction, Double>, String, Double>() { - @Override - public Iterable> call(Tuple2, Double> s) { - List> results = new ArrayList>(); - for (String n : s._1) { - results.add(new Tuple2(n, s._2 / s._1.size())); - } - return results; - } - }); - - // Re-calculates URL ranks based on neighbor contributions. - ranks = contribs.reduceByKey(new Sum()).mapValues(new Function() { - @Override - public Double call(Double sum) throws Exception { - return 0.15 + sum * 0.85; - } - }); - } - - // Collects all URL ranks and dump them to console. - List> output = ranks.collect(); - for (Tuple2 tuple : output) { - System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); - } - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaSparkPi.java b/examples/src/main/java/spark/examples/JavaSparkPi.java deleted file mode 100644 index d5f42fbb38..0000000000 --- a/examples/src/main/java/spark/examples/JavaSparkPi.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; - -import java.util.ArrayList; -import java.util.List; - -/** Computes an approximation to pi */ -public class JavaSparkPi { - - - public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [slices]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; - int n = 100000 * slices; - List l = new ArrayList(n); - for (int i = 0; i < n; i++) - l.add(i); - - JavaRDD dataSet = jsc.parallelize(l, slices); - - int count = dataSet.map(new Function() { - @Override - public Integer call(Integer integer) throws Exception { - double x = Math.random() * 2 - 1; - double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; - } - }).reduce(new Function2() { - @Override - public Integer call(Integer integer, Integer integer2) throws Exception { - return integer + integer2; - } - }); - - System.out.println("Pi is roughly " + 4.0 * count / n); - } -} diff --git a/examples/src/main/java/spark/examples/JavaTC.java b/examples/src/main/java/spark/examples/JavaTC.java deleted file mode 100644 index 559d7f9e53..0000000000 --- a/examples/src/main/java/spark/examples/JavaTC.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.PairFunction; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Random; -import java.util.Set; - -/** - * Transitive closure on a graph, implemented in Java. - */ -public class JavaTC { - - static int numEdges = 200; - static int numVertices = 100; - static Random rand = new Random(42); - - static List> generateGraph() { - Set> edges = new HashSet>(numEdges); - while (edges.size() < numEdges) { - int from = rand.nextInt(numVertices); - int to = rand.nextInt(numVertices); - Tuple2 e = new Tuple2(from, to); - if (from != to) edges.add(e); - } - return new ArrayList>(edges); - } - - static class ProjectFn extends PairFunction>, - Integer, Integer> { - static ProjectFn INSTANCE = new ProjectFn(); - - public Tuple2 call(Tuple2> triple) { - return new Tuple2(triple._2()._2(), triple._2()._1()); - } - } - - public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaTC []"); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; - JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); - - // Linear transitive closure: each round grows paths by one edge, - // by joining the graph's edges with the already-discovered paths. - // e.g. join the path (y, z) from the TC with the edge (x, y) from - // the graph to obtain the path (x, z). - - // Because join() joins on keys, the edges are stored in reversed order. - JavaPairRDD edges = tc.map( - new PairFunction, Integer, Integer>() { - public Tuple2 call(Tuple2 e) { - return new Tuple2(e._2(), e._1()); - } - }); - - long oldCount = 0; - long nextCount = tc.count(); - do { - oldCount = nextCount; - // Perform the join, obtaining an RDD of (y, (z, x)) pairs, - // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache(); - nextCount = tc.count(); - } while (nextCount != oldCount); - - System.out.println("TC has " + tc.count() + " edges."); - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaWordCount.java b/examples/src/main/java/spark/examples/JavaWordCount.java deleted file mode 100644 index 1af370c1c3..0000000000 --- a/examples/src/main/java/spark/examples/JavaWordCount.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; - -import java.util.Arrays; -import java.util.List; - -public class JavaWordCount { - public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaWordCount "); - System.exit(1); - } - - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = ctx.textFile(args[1], 1); - - JavaRDD words = lines.flatMap(new FlatMapFunction() { - public Iterable call(String s) { - return Arrays.asList(s.split(" ")); - } - }); - - JavaPairRDD ones = words.map(new PairFunction() { - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }); - - JavaPairRDD counts = ones.reduceByKey(new Function2() { - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - List> output = counts.collect(); - for (Tuple2 tuple : output) { - System.out.println(tuple._1 + ": " + tuple._2); - } - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaALS.java b/examples/src/main/java/spark/mllib/examples/JavaALS.java deleted file mode 100644 index b48f459cb7..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaALS.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.recommendation.ALS; -import spark.mllib.recommendation.MatrixFactorizationModel; -import spark.mllib.recommendation.Rating; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.StringTokenizer; - -import scala.Tuple2; - -/** - * Example using MLLib ALS from Java. - */ -public class JavaALS { - - static class ParseRating extends Function { - public Rating call(String line) { - StringTokenizer tok = new StringTokenizer(line, ","); - int x = Integer.parseInt(tok.nextToken()); - int y = Integer.parseInt(tok.nextToken()); - double rating = Double.parseDouble(tok.nextToken()); - return new Rating(x, y, rating); - } - } - - static class FeaturesToString extends Function, String> { - public String call(Tuple2 element) { - return element._1().toString() + "," + Arrays.toString(element._2()); - } - } - - public static void main(String[] args) { - - if (args.length != 5 && args.length != 6) { - System.err.println( - "Usage: JavaALS []"); - System.exit(1); - } - - int rank = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - String outputDir = args[4]; - int blocks = -1; - if (args.length == 6) { - blocks = Integer.parseInt(args[5]); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - - JavaRDD ratings = lines.map(new ParseRating()); - - MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); - - model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( - outputDir + "/userFeatures"); - model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( - outputDir + "/productFeatures"); - System.out.println("Final user/product features written to " + outputDir); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/spark/mllib/examples/JavaKMeans.java deleted file mode 100644 index 02f40438b8..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaKMeans.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.clustering.KMeans; -import spark.mllib.clustering.KMeansModel; - -import java.util.Arrays; -import java.util.StringTokenizer; - -/** - * Example using MLLib KMeans from Java. - */ -public class JavaKMeans { - - static class ParsePoint extends Function { - public double[] call(String line) { - StringTokenizer tok = new StringTokenizer(line, " "); - int numTokens = tok.countTokens(); - double[] point = new double[numTokens]; - for (int i = 0; i < numTokens; ++i) { - point[i] = Double.parseDouble(tok.nextToken()); - } - return point; - } - } - - public static void main(String[] args) { - - if (args.length < 4) { - System.err.println( - "Usage: JavaKMeans []"); - System.exit(1); - } - - String inputFile = args[1]; - int k = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - int runs = 1; - - if (args.length >= 5) { - runs = Integer.parseInt(args[4]); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - - JavaRDD points = lines.map(new ParsePoint()); - - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); - - System.out.println("Cluster centers:"); - for (double[] center : model.clusterCenters()) { - System.out.println(" " + Arrays.toString(center)); - } - double cost = model.computeCost(points.rdd()); - System.out.println("Cost: " + cost); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaLR.java b/examples/src/main/java/spark/mllib/examples/JavaLR.java deleted file mode 100644 index bf4aeaf40f..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaLR.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.classification.LogisticRegressionWithSGD; -import spark.mllib.classification.LogisticRegressionModel; -import spark.mllib.regression.LabeledPoint; - -import java.util.Arrays; -import java.util.StringTokenizer; - -/** - * Logistic regression based classification using ML Lib. - */ -public class JavaLR { - - static class ParsePoint extends Function { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - double y = Double.parseDouble(parts[0]); - StringTokenizer tok = new StringTokenizer(parts[1], " "); - int numTokens = tok.countTokens(); - double[] x = new double[numTokens]; - for (int i = 0; i < numTokens; ++i) { - x[i] = Double.parseDouble(tok.nextToken()); - } - return new LabeledPoint(y, x); - } - } - - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - - public static void main(String[] args) { - if (args.length != 4) { - System.err.println("Usage: JavaLR "); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[2]); - int iterations = Integer.parseInt(args[3]); - - // Another way to configure LogisticRegression - // - // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); - // lr.optimizer().setNumIterations(iterations) - // .setStepSize(stepSize) - // .setMiniBatchFraction(1.0); - // lr.setIntercept(true); - // LogisticRegressionModel model = lr.train(points.rdd()); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); - - System.out.print("Final w: "); - printWeights(model.weights()); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java deleted file mode 100644 index 096a9ae219..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import spark.api.java.function.Function; -import spark.streaming.*; -import spark.streaming.api.java.*; -import spark.streaming.dstream.SparkFlumeEvent; - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: JavaFlumeEventCount - * - * is a Spark master URL - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - */ -public class JavaFlumeEventCount { - public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaFlumeEventCount "); - System.exit(1); - } - - String master = args[0]; - String host = args[1]; - int port = Integer.parseInt(args[2]); - - Duration batchInterval = new Duration(2000); - - JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - JavaDStream flumeStream = sc.flumeStream("localhost", port); - - flumeStream.count(); - - flumeStream.count().map(new Function() { - @Override - public String call(Long in) { - return "Received " + in + " flume events."; - } - }).print(); - - sc.start(); - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java deleted file mode 100644 index c54d3f3d59..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; -import spark.streaming.Duration; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; - -/** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` - */ -public class JavaNetworkWordCount { - public static void main(String[] args) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1"); - System.exit(1); - } - - // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", - new Duration(1000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - JavaDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Lists.newArrayList(x.split(" ")); - } - }); - JavaPairDStream wordCounts = words.map( - new PairFunction() { - @Override - public Tuple2 call(String s) throws Exception { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }); - - wordCounts.print(); - ssc.start(); - - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java deleted file mode 100644 index 1f4a991542..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import spark.api.java.JavaRDD; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; -import spark.streaming.Duration; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; - -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; - -public class JavaQueueStream { - public static void main(String[] args) throws InterruptedException { - if (args.length < 1) { - System.err.println("Usage: JavaQueueStream "); - System.exit(1); - } - - // Create the context - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - Queue> rddQueue = new LinkedList>(); - - // Create and push some RDDs into the queue - List list = Lists.newArrayList(); - for (int i = 0; i < 1000; i++) { - list.add(i); - } - - for (int i = 0; i < 30; i++) { - rddQueue.add(ssc.sc().parallelize(list)); - } - - - // Create the QueueInputDStream and use it do some processing - JavaDStream inputStream = ssc.queueStream(rddQueue); - JavaPairDStream mappedStream = inputStream.map( - new PairFunction() { - @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i % 10, 1); - } - }); - JavaPairDStream reducedStream = mappedStream.reduceByKey( - new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }); - - reducedStream.print(); - ssc.start(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala new file mode 100644 index 0000000000..868ff81f67 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext + +object BroadcastTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: BroadcastTest [] [numElem]") + System.exit(1) + } + + val sc = new SparkContext(args(0), "Broadcast Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 + val num = if (args.length > 2) args(2).toInt else 1000000 + + var arr1 = new Array[Int](num) + for (i <- 0 until arr1.length) { + arr1(i) = i + } + + for (i <- 0 until 2) { + println("Iteration " + i) + println("===========") + val barr1 = sc.broadcast(arr1) + sc.parallelize(1 to 10, slices).foreach { + i => println(barr1.value.size) + } + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala new file mode 100644 index 0000000000..33bf7151a7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.hadoop.mapreduce.Job +import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat +import org.apache.cassandra.thrift._ +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.nio.ByteBuffer +import java.util.SortedMap +import org.apache.cassandra.db.IColumn +import org.apache.cassandra.utils.ByteBufferUtil +import scala.collection.JavaConversions._ + + +/* + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra + * support for Hadoop. + * + * To run this example, run this file with the following command params - + * + * + * So if you want to run this on localhost this will be, + * local[3] localhost 9160 + * + * The example makes some assumptions: + * 1. You have already created a keyspace called casDemo and it has a column family named Words + * 2. There are column family has a column named "para" which has test content. + * + * You can create the content by running the following script at the bottom of this file with + * cassandra-cli. + * + */ +object CassandraTest { + + def main(args: Array[String]) { + + // Get a SparkContext + val sc = new SparkContext(args(0), "casDemo") + + // Build the job configuration with ConfigHelper provided by Cassandra + val job = new Job() + job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) + + val host: String = args(1) + val port: String = args(2) + + ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setInputRpcPort(job.getConfiguration(), port) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") + + val predicate = new SlicePredicate() + val sliceRange = new SliceRange() + sliceRange.setStart(Array.empty[Byte]) + sliceRange.setFinish(Array.empty[Byte]) + predicate.setSlice_range(sliceRange) + ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) + + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + // Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD( + job.getConfiguration(), + classOf[ColumnFamilyInputFormat], + classOf[ByteBuffer], + classOf[SortedMap[ByteBuffer, IColumn]]) + + // Let us first get all the paragraphs from the retrieved rows + val paraRdd = casRdd.map { + case (key, value) => { + ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) + } + } + + // Lets get the word count in paras + val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) + + counts.collect().foreach { + case (word, count) => println(word + ":" + count) + } + + counts.map { + case (word, count) => { + val colWord = new org.apache.cassandra.thrift.Column() + colWord.setName(ByteBufferUtil.bytes("word")) + colWord.setValue(ByteBufferUtil.bytes(word)) + colWord.setTimestamp(System.currentTimeMillis) + + val colCount = new org.apache.cassandra.thrift.Column() + colCount.setName(ByteBufferUtil.bytes("wcount")) + colCount.setValue(ByteBufferUtil.bytes(count.toLong)) + colCount.setTimestamp(System.currentTimeMillis) + + val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) + + val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(0).column_or_supercolumn.setColumn(colWord) + mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(1).column_or_supercolumn.setColumn(colCount) + (outputkey, mutations) + } + }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], + classOf[ColumnFamilyOutputFormat], job.getConfiguration) + } +} + +/* +create keyspace casDemo; +use casDemo; + +create column family WordCount with comparator = UTF8Type; +update column family WordCount with column_metadata = + [{column_name: word, validation_class: UTF8Type}, + {column_name: wcount, validation_class: LongType}]; + +create column family Words with comparator = UTF8Type; +update column family Words with column_metadata = + [{column_name: book, validation_class: UTF8Type}, + {column_name: para, validation_class: UTF8Type}]; + +assume Words keys as utf8; + +set Words['3musk001']['book'] = 'The Three Musketeers'; +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market + town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to + be in as perfect a state of revolution as if the Huguenots had just made + a second La Rochelle of it. Many citizens, seeing the women flying + toward the High Street, leaving their children crying at the open doors, + hastened to don the cuirass, and supporting their somewhat uncertain + courage with a musket or a partisan, directed their steps toward the + hostelry of the Jolly Miller, before which was gathered, increasing + every minute, a compact group, vociferous and full of curiosity.'; + +set Words['3musk002']['book'] = 'The Three Musketeers'; +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without + some city or other registering in its archives an event of this kind. There were + nobles, who made war against each other; there was the king, who made + war against the cardinal; there was Spain, which made war against the + king. Then, in addition to these concealed or public, secret or open + wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, + who made war upon everybody. The citizens always took up arms readily + against thieves, wolves or scoundrels, often against nobles or + Huguenots, sometimes against the king, but never against cardinal or + Spain. It resulted, then, from this habit that on the said first Monday + of April, 1625, the citizens, on hearing the clamor, and seeing neither + the red-and-yellow standard nor the livery of the Duc de Richelieu, + rushed toward the hostel of the Jolly Miller. When arrived there, the + cause of the hubbub was apparent to all'; + +set Words['3musk003']['book'] = 'The Three Musketeers'; +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however + large the sum may be; but you ought also to endeavor to perfect yourself in + the exercises becoming a gentleman. I will write a letter today to the + Director of the Royal Academy, and tomorrow he will admit you without + any expense to yourself. Do not refuse this little service. Our + best-born and richest gentlemen sometimes solicit it without being able + to obtain it. You will learn horsemanship, swordsmanship in all its + branches, and dancing. You will make some desirable acquaintances; and + from time to time you can call upon me, just to tell me how you are + getting on, and to say whether I can be of further service to you.'; + + +set Words['thelostworld001']['book'] = 'The Lost World'; +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined + against the red curtain. How beautiful she was! And yet how aloof! We had been + friends, quite good friends; but never could I get beyond the same + comradeship which I might have established with one of my + fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, + and perfectly unsexual. My instincts are all against a woman being too + frank and at her ease with me. It is no compliment to a man. Where + the real sex feeling begins, timidity and distrust are its companions, + heritage from old wicked days when love and violence went often hand in + hand. The bent head, the averted eye, the faltering voice, the wincing + figure--these, and not the unshrinking gaze and frank reply, are the + true signals of passion. Even in my short life I had learned as much + as that--or had inherited it in that race memory which we call instinct.'; + +set Words['thelostworld002']['book'] = 'The Lost World'; +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, + red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was + the real boss; but he lived in the rarefied atmosphere of some Olympian + height from which he could distinguish nothing smaller than an + international crisis or a split in the Cabinet. Sometimes we saw him + passing in lonely majesty to his inner sanctum, with his eyes staring + vaguely and his mind hovering over the Balkans or the Persian Gulf. He + was above and beyond us. But McArdle was his first lieutenant, and it + was he that we knew. The old man nodded as I entered the room, and he + pushed his spectacles far up on his bald forehead.'; + +*/ diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala new file mode 100644 index 0000000000..92eb96bd8e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext + +object ExceptionHandlingTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: ExceptionHandlingTest ") + System.exit(1) + } + + val sc = new SparkContext(args(0), "ExceptionHandlingTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + sc.parallelize(0 until sc.defaultParallelism).foreach { i => + if (math.random > 0.75) + throw new Exception("Testing exception handling") + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala new file mode 100644 index 0000000000..42c2e0e8e1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.util.Random + +object GroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + + val sc = new SparkContext(args(0), "GroupBy Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) + } + arr1 + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println(pairs1.groupByKey(numReducers).count) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala new file mode 100644 index 0000000000..efe2e93b0d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark._ +import org.apache.spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase table if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val tableDesc = new HTableDescriptor(args(1)) + admin.createTable(tableDesc) + } + + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result]) + + hBaseRDD.count() + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala new file mode 100644 index 0000000000..d6a88d3032 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark._ + +object HdfsTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HdfsTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val file = sc.textFile(args(1)) + val mapped = file.map(s => s.length).cache() + for (iter <- 1 to 10) { + val start = System.currentTimeMillis() + for (x <- mapped) { x + 2 } + // println("Processing: " + x) + val end = System.currentTimeMillis() + println("Iteration " + iter + " took " + (end-start) + " ms") + } + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala new file mode 100644 index 0000000000..4af45b2b4a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.sqrt +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ + +/** + * Alternating least squares matrix factorization. + */ +object LocalALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return sqrt(sumSqs / (M * U)) + } + + def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each movie that the user rated + for (i <- 0 until M) { + val m = ms(i) + // Add m * m^t to XtX + blas.dger(1, m, m, XtX) + // Add m * rating to Xty + blas.daxpy(R.get(i, j), m, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + args match { + case Array(m, u, f, iters) => { + M = m.toInt + U = u.toInt + F = f.toInt + ITERATIONS = iters.toInt + } + case _ => { + System.err.println("Usage: LocalALS ") + System.exit(1) + } + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fill(M)(factory1D.random(F)) + var us = Array.fill(U)(factory1D.random(F)) + + // Iteratively update movies then users + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray + us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray + println("RMSE = " + rmse(R, ms, us)) + println() + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala new file mode 100644 index 0000000000..fb130ea198 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector + +object LocalFileLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + val nums = line.split(' ').map(_.toDouble) + return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) + } + + def main(args: Array[String]) { + val lines = scala.io.Source.fromFile(args(0)).getLines().toArray + val points = lines.map(parsePoint _) + val ITERATIONS = args(1).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- points) { + val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala new file mode 100644 index 0000000000..f90ea35cd4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector +import org.apache.spark.SparkContext._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +/** + * K-means clustering. + */ +object LocalKMeans { + val N = 1000 + val R = 1000 // Scaling factor + val D = 10 + val K = 10 + val convergeDist = 0.001 + val rand = new Random(42) + + def generateData = { + def generatePoint(i: Int) = { + Vector(D, _ => rand.nextDouble * R) + } + Array.tabulate(N)(generatePoint) + } + + def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { + var index = 0 + var bestIndex = 0 + var closest = Double.PositiveInfinity + + for (i <- 1 to centers.size) { + val vCurr = centers.get(i).get + val tempDist = p.squaredDist(vCurr) + if (tempDist < closest) { + closest = tempDist + bestIndex = i + } + } + + return bestIndex + } + + def main(args: Array[String]) { + val data = generateData + var points = new HashSet[Vector] + var kPoints = new HashMap[Int, Vector] + var tempDist = 1.0 + + while (points.size < K) { + points.add(data(rand.nextInt(N))) + } + + val iter = points.iterator + for (i <- 1 to points.size) { + kPoints.put(i, iter.next()) + } + + println("Initial centers: " + kPoints) + + while(tempDist > convergeDist) { + var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) + + var mappings = closest.groupBy[Int] (x => x._1) + + var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))}) + + var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} + + tempDist = 0.0 + for (mapping <- newPoints) { + tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) + } + + for (newP <- newPoints) { + kPoints.put(newP._1, newP._2) + } + } + + println("Final centers: " + kPoints) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala new file mode 100644 index 0000000000..cd4e9f1af0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector + +/** + * Logistic regression based classification. + */ +object LocalLR { + val N = 10000 // Number of data points + val D = 10 // Number of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.tabulate(N)(generatePoint) + } + + def main(args: Array[String]) { + val data = generateData + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- data) { + val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala new file mode 100644 index 0000000000..bb7f22ec8d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.random +import org.apache.spark._ +import SparkContext._ + +object LocalPi { + def main(args: Array[String]) { + var count = 0 + for (i <- 1 to 100000) { + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x*x + y*y < 1) count += 1 + } + println("Pi is roughly " + 4 * count / 100000.0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala new file mode 100644 index 0000000000..17ff3ce764 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +/** + * Executes a roll up-style query against Apache logs. + */ +object LogQuery { + val exampleApacheLogs = List( + """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" + | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""), + """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" + | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "") + ) + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: LogQuery [logFile]") + System.exit(1) + } + + val sc = new SparkContext(args(0), "Log Query", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val dataSet = + if (args.length == 2) sc.textFile(args(1)) + else sc.parallelize(exampleApacheLogs) + + val apacheLogRegex = + """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r + + /** Tracks the total query count and number of aggregate bytes for a particular group. */ + class Stats(val count: Int, val numBytes: Int) extends Serializable { + def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) + override def toString = "bytes=%s\tn=%s".format(numBytes, count) + } + + def extractKey(line: String): (String, String, String) = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + if (user != "\"-\"") (ip, user, query) + else (null, null, null) + case _ => (null, null, null) + } + } + + def extractStats(line: String): Stats = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + new Stats(1, bytes.toInt) + case _ => new Stats(1, 0) + } + } + + dataSet.map(line => (extractKey(line), extractStats(line))) + .reduceByKey((a, b) => a.merge(b)) + .collect().foreach{ + case (user, query) => println("%s\t%s".format(user, query))} + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala new file mode 100644 index 0000000000..f79f0142b8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext + +object MultiBroadcastTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: BroadcastTest [] [numElem]") + System.exit(1) + } + + val sc = new SparkContext(args(0), "Broadcast Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val slices = if (args.length > 1) args(1).toInt else 2 + val num = if (args.length > 2) args(2).toInt else 1000000 + + var arr1 = new Array[Int](num) + for (i <- 0 until arr1.length) { + arr1(i) = i + } + + var arr2 = new Array[Int](num) + for (i <- 0 until arr2.length) { + arr2(i) = i + } + + val barr1 = sc.broadcast(arr1) + val barr2 = sc.broadcast(arr2) + sc.parallelize(1 to 10, slices).foreach { + i => println(barr1.value.size + barr2.value.size) + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala new file mode 100644 index 0000000000..37ddfb5db7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.util.Random + +object SimpleSkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SimpleSkewedGroupByTest " + + "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + var ratio = if (args.length > 5) args(5).toInt else 5.0 + + val sc = new SparkContext(args(0), "GroupBy Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + var result = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + val offset = ranGen.nextInt(1000) * numReducers + if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) { + // give ratio times higher chance of generating key 0 (for reducer 0) + result(i) = (offset, byteArr) + } else { + // generate a key for one of the other reducers + val key = 1 + ranGen.nextInt(numReducers-1) + offset + result(i) = (key, byteArr) + } + } + result + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println("RESULT: " + pairs1.groupByKey(numReducers).count) + // Print how many keys each reducer got (for debugging) + //println("RESULT: " + pairs1.groupByKey(numReducers) + // .map{case (k,v) => (k, v.size)} + // .collectAsMap) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala new file mode 100644 index 0000000000..9c954b2b5b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.util.Random + +object SkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + + val sc = new SparkContext(args(0), "GroupBy Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + + // map output sizes lineraly increase from the 1st to the last + numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt + + var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) + } + arr1 + }.cache() + // Enforce that everything has been calculated and in cache + pairs1.count() + + println(pairs1.groupByKey(numReducers).count()) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala new file mode 100644 index 0000000000..814944ba1c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.sqrt +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ +import org.apache.spark._ + +/** + * Alternating least squares matrix factorization. + */ +object SparkALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return sqrt(sumSqs / (M * U)) + } + + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val U = us.size + val F = us(0).size + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkALS [ ]") + System.exit(1) + } + + var host = "" + var slices = 0 + + val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) + + options.toArray match { + case Array(host_, m, u, f, iters, slices_) => + host = host_.get + M = m.getOrElse("100").toInt + U = u.getOrElse("500").toInt + F = f.getOrElse("10").toInt + ITERATIONS = iters.getOrElse("5").toInt + slices = slices_.getOrElse("2").toInt + case _ => + System.err.println("Usage: SparkALS [ ]") + System.exit(1) + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + + val sc = new SparkContext(host, "SparkALS", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fill(M)(factory1D.random(F)) + var us = Array.fill(U)(factory1D.random(F)) + + // Iteratively update movies then users + val Rc = sc.broadcast(R) + var msb = sc.broadcast(ms) + var usb = sc.broadcast(us) + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = sc.parallelize(0 until M, slices) + .map(i => update(i, msb.value(i), usb.value, Rc.value)) + .toArray + msb = sc.broadcast(ms) // Re-broadcast ms because it was updated + us = sc.parallelize(0 until U, slices) + .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) + .toArray + usb = sc.broadcast(us) // Re-broadcast us because it was updated + println("RMSE = " + rmse(R, ms, us)) + println() + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala new file mode 100644 index 0000000000..646682878f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import scala.math.exp +import org.apache.spark.util.Vector +import org.apache.spark._ +import org.apache.spark.scheduler.InputFormatInfo + +/** + * Logistic regression based classification. + */ +object SparkHdfsLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + //val nums = line.split(' ').map(_.toDouble) + //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) + val tok = new java.util.StringTokenizer(line, " ") + var y = tok.nextToken.toDouble + var x = new Array[Double](D) + var i = 0 + while (i < D) { + x(i) = tok.nextToken.toDouble; i += 1 + } + return DataPoint(new Vector(x), y) + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: SparkHdfsLR ") + System.exit(1) + } + val inputPath = args(1) + val conf = SparkEnv.get.hadoop.newConfiguration() + val sc = new SparkContext(args(0), "SparkHdfsLR", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), + InputFormatInfo.computePreferredLocations( + Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) + val lines = sc.textFile(inputPath) + val points = lines.map(parsePoint _).cache() + val ITERATIONS = args(2).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = points.map { p => + (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x + }.reduce(_ + _) + w -= gradient + } + + println("Final w: " + w) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala new file mode 100644 index 0000000000..f7bf75b4e5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.SparkContext +import org.apache.spark.util.Vector +import org.apache.spark.SparkContext._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +/** + * K-means clustering. + */ +object SparkKMeans { + val R = 1000 // Scaling factor + val rand = new Random(42) + + def parseVector(line: String): Vector = { + return new Vector(line.split(' ').map(_.toDouble)) + } + + def closestPoint(p: Vector, centers: Array[Vector]): Int = { + var index = 0 + var bestIndex = 0 + var closest = Double.PositiveInfinity + + for (i <- 0 until centers.length) { + val tempDist = p.squaredDist(centers(i)) + if (tempDist < closest) { + closest = tempDist + bestIndex = i + } + } + + return bestIndex + } + + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: SparkLocalKMeans ") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkLocalKMeans", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val lines = sc.textFile(args(1)) + val data = lines.map(parseVector _).cache() + val K = args(2).toInt + val convergeDist = args(3).toDouble + + var kPoints = data.takeSample(false, K, 42).toArray + var tempDist = 1.0 + + while(tempDist > convergeDist) { + var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) + + var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} + + var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() + + tempDist = 0.0 + for (i <- 0 until K) { + tempDist += kPoints(i).squaredDist(newPoints(i)) + } + + for (newP <- newPoints) { + kPoints(newP._1) = newP._2 + } + println("Finished iteration (delta = " + tempDist + ")") + } + + println("Final centers:") + kPoints.foreach(println) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala new file mode 100644 index 0000000000..9ed9fe4d76 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import scala.math.exp +import org.apache.spark.util.Vector +import org.apache.spark._ + +/** + * Logistic regression based classification. + */ +object SparkLR { + val N = 10000 // Number of data points + val D = 10 // Numer of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.tabulate(N)(generatePoint) + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkLR []") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkLR", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val numSlices = if (args.length > 1) args(1).toInt else 2 + val points = sc.parallelize(generateData, numSlices).cache() + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = points.map { p => + (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x + }.reduce(_ + _) + w -= gradient + } + + println("Final w: " + w) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala new file mode 100644 index 0000000000..2721caf08b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -0,0 +1,46 @@ +package org.apache.spark.examples + +import org.apache.spark.SparkContext._ +import org.apache.spark.SparkContext + + +/** + * Computes the PageRank of URLs from an input file. Input file should + * be in format of: + * URL neighbor URL + * URL neighbor URL + * URL neighbor URL + * ... + * where URL and their neighbors are separated by space(s). + */ +object SparkPageRank { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: PageRank ") + System.exit(1) + } + var iters = args(2).toInt + val ctx = new SparkContext(args(0), "PageRank", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val lines = ctx.textFile(args(1), 1) + val links = lines.map{ s => + val parts = s.split("\\s+") + (parts(0), parts(1)) + }.distinct().groupByKey().cache() + var ranks = links.mapValues(v => 1.0) + + for (i <- 1 to iters) { + val contribs = links.join(ranks).values.flatMap{ case (urls, rank) => + val size = urls.size + urls.map(url => (url, rank / size)) + } + ranks = contribs.reduceByKey(_ + _).mapValues(0.15 + 0.85 * _) + } + + val output = ranks.collect() + output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala new file mode 100644 index 0000000000..5a2bc9b0d0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.random +import org.apache.spark._ +import SparkContext._ + +/** Computes an approximation to pi */ +object SparkPi { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkPi []") + System.exit(1) + } + val spark = new SparkContext(args(0), "SparkPi", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 + val n = 100000 * slices + val count = spark.parallelize(1 to n, slices).map { i => + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x*x + y*y < 1) 1 else 0 + }.reduce(_ + _) + println("Pi is roughly " + 4.0 * count / n) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala new file mode 100644 index 0000000000..5a7a9d1bd8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark._ +import SparkContext._ +import scala.util.Random +import scala.collection.mutable + +/** + * Transitive closure on a graph. + */ +object SparkTC { + val numEdges = 200 + val numVertices = 100 + val rand = new Random(42) + + def generateGraph = { + val edges: mutable.Set[(Int, Int)] = mutable.Set.empty + while (edges.size < numEdges) { + val from = rand.nextInt(numVertices) + val to = rand.nextInt(numVertices) + if (from != to) edges.+=((from, to)) + } + edges.toSeq + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkTC []") + System.exit(1) + } + val spark = new SparkContext(args(0), "SparkTC", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 + var tc = spark.parallelize(generateGraph, slices).cache() + + // Linear transitive closure: each round grows paths by one edge, + // by joining the graph's edges with the already-discovered paths. + // e.g. join the path (y, z) from the TC with the edge (x, y) from + // the graph to obtain the path (x, z). + + // Because join() joins on keys, the edges are stored in reversed order. + val edges = tc.map(x => (x._2, x._1)) + + // This join is iterated until a fixed point is reached. + var oldCount = 0L + var nextCount = tc.count() + do { + oldCount = nextCount + // Perform the join, obtaining an RDD of (y, (z, x)) pairs, + // then project the result to obtain the new (x, z) paths. + tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); + nextCount = tc.count() + } while (nextCount != oldCount) + + println("TC has " + tc.count() + " edges.") + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala new file mode 100644 index 0000000000..b190e83c4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +import com.esotericsoftware.kryo._ + +class PageRankUtils extends Serializable { + def computeWithCombiner(numVertices: Long, epsilon: Double)( + self: PRVertex, messageSum: Option[Double], superstep: Int + ): (PRVertex, Array[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = superstep >= 10 + + val outbox: Array[PRMessage] = + if (!terminate) + self.outEdges.map(targetId => + new PRMessage(targetId, newValue / self.outEdges.size)) + else + Array[PRMessage]() + + (new PRVertex(newValue, self.outEdges, !terminate), outbox) + } + + def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = + computeWithCombiner(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +class PRCombiner extends Combiner[PRMessage, Double] with Serializable { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b +} + +class PRVertex() extends Vertex with Serializable { + var value: Double = _ + var outEdges: Array[String] = _ + var active: Boolean = _ + + def this(value: Double, outEdges: Array[String], active: Boolean = true) { + this() + this.value = value + this.outEdges = outEdges + this.active = active + } + + override def toString(): String = { + "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) + } +} + +class PRMessage() extends Message[String] with Serializable { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + } +} + +class CustomPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val hash = key match { + case k: Long => (k & 0x00000000FFFFFFFFL).toInt + case _ => key.hashCode + } + + val mod = key.hashCode % partitions + if (mod < 0) mod + partitions else mod + } + + override def equals(other: Any): Boolean = other match { + case c: CustomPartitioner => + c.numPartitions == numPartitions + case _ => false + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala new file mode 100644 index 0000000000..b1f606e48e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +/** + * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" + * files from there, which contains one line per wiki article in a tab-separated format + * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). + */ +object WikipediaPageRank { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRank ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + + val inputFile = args(0) + val threshold = args(1).toDouble + val numPartitions = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRank") + + // Parse the Wikipedia page data into a graph + val input = sc.textFile(inputFile) + + println("Counting vertices...") + val numVertices = input.count() + println("Done counting vertices.") + + println("Parsing input file...") + var vertices = input.map(line => { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + val id = new String(title) + (id, new PRVertex(1.0 / numVertices, outEdges)) + }) + if (usePartitioner) + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + else + vertices = vertices.cache + println("Done parsing input file.") + + // Do the computation + val epsilon = 0.01 / numVertices + val messages = sc.parallelize(Array[(String, PRMessage)]()) + val utils = new PageRankUtils + val result = + Bagel.run( + sc, vertices, messages, combiner = new PRCombiner(), + numPartitions = numPartitions)( + utils.computeWithCombiner(numVertices, epsilon)) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (result + .filter { case (id, vertex) => vertex.value >= threshold } + .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } + .collect.mkString) + println(top) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala new file mode 100644 index 0000000000..3bfa48eaf3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer + +object WikipediaPageRankStandalone { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRankStandalone ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") + + val inputFile = args(0) + val threshold = args(1).toDouble + val numIterations = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRankStandalone") + + val input = sc.textFile(inputFile) + val partitioner = new HashPartitioner(sc.defaultParallelism) + val links = + if (usePartitioner) + input.map(parseArticle _).partitionBy(partitioner).cache() + else + input.map(parseArticle _).cache() + val n = links.count() + val defaultRank = 1.0 / n + val a = 0.15 + + // Do the computation + val startTime = System.currentTimeMillis + val ranks = + pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (ranks + .filter { case (id, rank) => rank >= threshold } + .map { case (id, rank) => "%s\t%s\n".format(id, rank) } + .collect().mkString) + println(top) + + val time = (System.currentTimeMillis - startTime) / 1000.0 + println("Completed %d iterations in %f seconds: %f seconds per iteration" + .format(numIterations, time, time / numIterations)) + System.exit(0) + } + + def parseArticle(line: String): (String, Array[String]) = { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val id = new String(title) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + (id, outEdges) + } + + def pageRank( + links: RDD[(String, Array[String])], + numIterations: Int, + defaultRank: Double, + a: Double, + n: Long, + partitioner: Partitioner, + usePartitioner: Boolean, + numPartitions: Int + ): RDD[(String, Double)] = { + var ranks = links.mapValues { edges => defaultRank } + for (i <- 1 to numIterations) { + val contribs = links.groupWith(ranks).flatMap { + case (id, (linksWrapper, rankWrapper)) => + if (linksWrapper.length > 0) { + if (rankWrapper.length > 0) { + linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + } else { + linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + } + } else { + Array[(String, Double)]() + } + } + ranks = (contribs.combineByKey((x: Double) => x, + (x: Double, y: Double) => x + y, + (x: Double, y: Double) => x + y, + partitioner) + .mapValues(sum => a/n + (1-a)*sum)) + } + ranks + } +} + +class WPRSerializer extends org.apache.spark.serializer.Serializer { + def newInstance(): SerializerInstance = new WPRSerializerInstance() +} + +class WPRSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer): T = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + + def serializeStream(s: OutputStream): SerializationStream = { + new WPRSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new WPRDeserializationStream(s) + } +} + +class WPRSerializationStream(os: OutputStream) extends SerializationStream { + val dos = new DataOutputStream(os) + + def writeObject[T](t: T): SerializationStream = t match { + case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { + case links: Array[String] => { + dos.writeInt(0) // links + dos.writeUTF(id) + dos.writeInt(links.length) + for (link <- links) { + dos.writeUTF(link) + } + this + } + case rank: Double => { + dos.writeInt(1) // rank + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + case (id: String, rank: Double) => { + dos.writeInt(2) // rank without wrapper + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + + def flush() { dos.flush() } + def close() { dos.close() } +} + +class WPRDeserializationStream(is: InputStream) extends DeserializationStream { + val dis = new DataInputStream(is) + + def readObject[T](): T = { + val typeId = dis.readInt() + typeId match { + case 0 => { + val id = dis.readUTF() + val numLinks = dis.readInt() + val links = new Array[String](numLinks) + for (i <- 0 until numLinks) { + val link = dis.readUTF() + links(i) = link + } + (id, ArrayBuffer(links)).asInstanceOf[T] + } + case 1 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, ArrayBuffer(rank)).asInstanceOf[T] + } + case 2 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, rank).asInstanceOf[T] + } + } + } + + def close() { dis.close() } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala new file mode 100644 index 0000000000..cd3423a07b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import scala.collection.mutable.LinkedList +import scala.util.Random + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.actor.actorRef2Scala + +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.receivers.Receiver +import org.apache.spark.util.AkkaUtils + +case class SubscribeReceiver(receiverActor: ActorRef) +case class UnsubscribeReceiver(receiverActor: ActorRef) + +/** + * Sends the random content to every receiver subscribed with 1/2 + * second delay. + */ +class FeederActor extends Actor { + + val rand = new Random() + var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() + + val strings: Array[String] = Array("words ", "may ", "count ") + + def makeMessage(): String = { + val x = rand.nextInt(3) + strings(x) + strings(2 - x) + } + + /* + * A thread to generate random messages + */ + new Thread() { + override def run() { + while (true) { + Thread.sleep(500) + receivers.foreach(_ ! makeMessage) + } + } + }.start() + + def receive: Receive = { + + case SubscribeReceiver(receiverActor: ActorRef) => + println("received subscribe from %s".format(receiverActor.toString)) + receivers = LinkedList(receiverActor) ++ receivers + + case UnsubscribeReceiver(receiverActor: ActorRef) => + println("received unsubscribe from %s".format(receiverActor.toString)) + receivers = receivers.dropWhile(x => x eq receiverActor) + + } +} + +/** + * A sample actor as receiver, is also simplest. This receiver actor + * goes and subscribe to a typical publisher/feeder actor and receives + * data. + * + * @see [[org.apache.spark.streaming.examples.FeederActor]] + */ +class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) +extends Actor with Receiver { + + lazy private val remotePublisher = context.actorFor(urlOfPublisher) + + override def preStart = remotePublisher ! SubscribeReceiver(context.self) + + def receive = { + case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T]) + } + + override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + +} + +/** + * A sample feeder actor + * + * Usage: FeederActor + * and describe the AkkaSystem that Spark Sample feeder would start on. + */ +object FeederActor { + + def main(args: Array[String]) { + if(args.length < 2){ + System.err.println( + "Usage: FeederActor \n" + ) + System.exit(1) + } + val Seq(host, port) = args.toSeq + + + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1 + val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") + + println("Feeder started as:" + feeder) + + actorSystem.awaitTermination(); + } +} + +/** + * A sample word count program demonstrating the use of plugging in + * Actor as Receiver + * Usage: ActorWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the AkkaSystem that Spark Sample feeder is running on. + * + * To run this example locally, you may run Feeder Actor as + * `$ ./run-example spark.streaming.examples.FeederActor 127.0.1.1 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999` + */ +object ActorWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ActorWordCount " + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val Seq(master, host, port) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + /* + * Following is the use of actorStream to plug in custom actor as receiver + * + * An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e type of data received and InputDstream + * should be same. + * + * For example: Both actorStream and SampleActorReceiver are parameterized + * to same type to ensure type safety. + */ + + val lines = ssc.actorStream[String]( + Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( + host, port.toInt))), "SampleReceiver") + + //compute wordcount + lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala new file mode 100644 index 0000000000..9f6e163454 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.util.IntParam +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ +object FlumeEventCount { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println( + "Usage: FlumeEventCount ") + System.exit(1) + } + + val Array(master, host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + // Create the context and set the batch size + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a flume stream + val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." ).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala new file mode 100644 index 0000000000..bc8564b3ba --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + + +/** + * Counts words in new text files created in the given directory + * Usage: HdfsWordCount + * is the Spark master URL. + * is the directory that Spark Streaming will use to find and read new text files. + * + * To run this on your local machine on directory `localdir`, run this example + * `$ ./run-example spark.streaming.examples.HdfsWordCount local[2] localdir` + * Then create a text file in `localdir` and the words in the file will get counted. + */ +object HdfsWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: HdfsWordCount ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala new file mode 100644 index 0000000000..12f939d5a7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import java.util.Properties +import kafka.message.Message +import kafka.producer.SyncProducerConfig +import kafka.producer._ +import org.apache.spark.SparkContext +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.util.RawTextHelper._ + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: KafkaWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * is a list of one or more zookeeper servers that make quorum + * is the name of kafka consumer group + * is a list of one or more kafka topics to consume from + * is the number of threads the kafka consumer should use + * + * Example: + * `./run-example spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` + */ +object KafkaWordCount { + def main(args: Array[String]) { + + if (args.length < 5) { + System.err.println("Usage: KafkaWordCount ") + System.exit(1) + } + + val Array(master, zkQuorum, group, topics, numThreads) = args + + val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint("checkpoint") + + val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + wordCounts.print() + + ssc.start() + } +} + +// Produces some random words between 1 and 100. +object KafkaWordCountProducer { + + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: KafkaWordCountProducer ") + System.exit(1) + } + + val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", zkQuorum) + props.put("serializer.class", "kafka.serializer.StringEncoder") + + val config = new ProducerConfig(props) + val producer = new Producer[String, String](config) + + // Send some messages + while(true) { + val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") + }.toArray + println(messages.mkString(",")) + val data = new ProducerData[String, String](topic, messages) + producer.send(data) + Thread.sleep(100) + } + } + +} + diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala new file mode 100644 index 0000000000..e2487dca5f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.NetworkWordCount local[2] localhost 9999` + */ +object NetworkWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala new file mode 100644 index 0000000000..822da8c9b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.RDD +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + +import scala.collection.mutable.SynchronizedQueue + +object QueueStream { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: QueueStream ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputDStream and use it do some processing + val inputStream = ssc.queueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala new file mode 100644 index 0000000000..2e3d9ccf00 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.util.IntParam +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.util.RawTextHelper + +/** + * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited + * lines have the word 'the' in them. This is useful for benchmarking purposes. This + * will only work with spark.streaming.util.RawTextSender running on all worker nodes + * and with Spark using Kryo serialization (set Java property "spark.serializer" to + * "org.apache.spark.KryoSerializer"). + * Usage: RawNetworkGrep + * is the Spark master URL + * is the number rawNetworkStreams, which should be same as number + * of work nodes in the cluster + * is "localhost". + * is the port on which RawTextSender is running in the worker nodes. + * is the Spark Streaming batch duration in milliseconds. + */ + +object RawNetworkGrep { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: RawNetworkGrep ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + RawTextHelper.warmUp(ssc.sparkContext) + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("the")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala new file mode 100644 index 0000000000..cb30c4edb3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: StatefulNetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` + */ +object StatefulNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: StatefulNetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + val currentCount = values.foldLeft(0)(_ + _) + + val previousCount = state.getOrElse(0) + + Some(currentCount + previousCount) + } + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint(".") + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordDstream = words.map(x => (x, 1)) + + // Update the cumulative count using updateStateByKey + // This will give a Dstream made of state (which is the cumulative count of the words) + val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + stateDstream.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala new file mode 100644 index 0000000000..35b6329ab3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel +import com.twitter.algebird._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.SparkContext._ + +/** + * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute + * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. + *
    + * Note that since Algebird's implementation currently only supports Long inputs, + * the example operates on Long IDs. Once the implementation supports other inputs (such as String), + * the same approach could be used for computing popular topics for example. + *

    + *

    + * + * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a datastructure + * for approximate frequency estimation in data streams (e.g. Top-K elements, frequency of any given element, etc), + * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the + * estimated count of an element can be computed, as well as "heavy-hitters" that occur more than a threshold + * percentage of the overall total count. + *

    + * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation. + */ +object TwitterAlgebirdCMS { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdCMS " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + // CMS parameters + val DELTA = 1E-3 + val EPS = 0.01 + val SEED = 1 + val PERC = 0.001 + // K highest frequency elements to take + val TOPK = 10 + + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + var globalCMS = cms.zero + val mm = new MapMonoid[Long, Int]() + var globalExact = Map[Long, Int]() + + val approxTopUsers = users.mapPartitions(ids => { + ids.map(id => cms.create(id)) + }).reduce(_ ++ _) + + val exactTopUsers = users.map(id => (id, 1)) + .reduceByKey((a, b) => a + b) + + approxTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + val partialTopK = partial.heavyHitters.map(id => + (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + globalCMS ++= partial + val globalTopK = globalCMS.heavyHitters.map(id => + (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, + partialTopK.mkString("[", ",", "]"))) + println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, + globalTopK.mkString("[", ",", "]"))) + } + }) + + exactTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partialMap = rdd.collect().toMap + val partialTopK = rdd.map( + {case (id, count) => (count, id)}) + .sortByKey(ascending = false).take(TOPK) + globalExact = mm.plus(globalExact.toMap, partialMap) + val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) + println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala new file mode 100644 index 0000000000..8bfde2a829 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel +import com.twitter.algebird.HyperLogLog._ +import com.twitter.algebird.HyperLogLogMonoid +import org.apache.spark.streaming.dstream.TwitterInputDStream + +/** + * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute + * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. + *

    + *

    + * This + * blog post and this + * blog post + * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for estimating + * the cardinality of a data stream, i.e. the number of unique elements. + *

    + * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation. + */ +object TwitterAlgebirdHLL { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdHLL " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ + val BIT_SIZE = 12 + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val hll = new HyperLogLogMonoid(BIT_SIZE) + var globalHll = hll.zero + var userSet: Set[Long] = Set() + + val approxUsers = users.mapPartitions(ids => { + ids.map(id => hll(id)) + }).reduce(_ + _) + + val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) + + approxUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + globalHll += partial + println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) + println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) + } + }) + + exactUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + userSet ++= partial + println("Exact distinct users this batch: %d".format(partial.size)) + println("Exact distinct users overall: %d".format(userSet.size)) + println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100)) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala new file mode 100644 index 0000000000..27aa6b14bf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import StreamingContext._ +import org.apache.spark.SparkContext._ + +/** + * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter + * stream. The stream is instantiated with credentials and optionally filters supplied by the + * command line arguments. + * + */ +object TwitterPopularTags { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterPopularTags " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters) + + val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) + + val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + + // Print popular hashtags + topCounts60.foreach(rdd => { + val topList = rdd.take(5) + println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} + }) + + topCounts10.foreach(rdd => { + val topList = rdd.take(5) + println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala new file mode 100644 index 0000000000..c8743b9e25 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import akka.actor.ActorSystem +import akka.actor.actorRef2Scala +import akka.zeromq._ +import org.apache.spark.streaming.{ Seconds, StreamingContext } +import org.apache.spark.streaming.StreamingContext._ +import akka.zeromq.Subscribe + +/** + * A simple publisher for demonstration purposes, repeatedly publishes random Messages + * every one second. + */ +object SimpleZeroMQPublisher { + + def main(args: Array[String]) = { + if (args.length < 2) { + System.err.println("Usage: SimpleZeroMQPublisher ") + System.exit(1) + } + + val Seq(url, topic) = args.toSeq + val acs: ActorSystem = ActorSystem() + + val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) + val messages: Array[String] = Array("words ", "may ", "count ") + while (true) { + Thread.sleep(1000) + pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) + } + acs.awaitTermination() + } +} + +/** + * A sample wordcount with ZeroMQStream stream + * + * To work with zeroMQ, some native libraries have to be installed. + * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software) + * + * Usage: ZeroMQWordCount + * In local mode, should be 'local[n]' with n > 1 + * and describe where zeroMq publisher is running. + * + * To run this example locally, you may run publisher as + * `$ ./run-example spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * and run the example as + * `$ ./run-example spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` + */ +object ZeroMQWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ZeroMQWordCount " + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + val Seq(master, url, topic) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator + + //For this stream, a zeroMQ publisher should be running. + val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala new file mode 100644 index 0000000000..884d6d6f34 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples.clickstream + +import java.net.{InetAddress,ServerSocket,Socket,SocketException} +import java.io.{InputStreamReader, BufferedReader, PrintWriter} +import util.Random + +/** Represents a page view on a website with associated dimension data.*/ +class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { + override def toString() : String = { + "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) + } +} +object PageView { + def fromString(in : String) : PageView = { + val parts = in.split("\t") + new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) + } +} + +/** Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewGenerator { + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, + "http://foo.com/contact" -> .1) + val httpStatus = Map(200 -> .95, + 404 -> .05) + val userZipCode = Map(94709 -> .5, + 94117 -> .5) + val userID = Map((1 to 100).map(_ -> .01):_*) + + + def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + val rand = new Random().nextDouble() + var total = 0.0 + for ((item, prob) <- inputMap) { + total = total + prob + if (total > rand) { + return item + } + } + return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 + } + + def getNextClickEvent() : String = { + val id = pickFromDistribution(userID) + val page = pickFromDistribution(pages) + val status = pickFromDistribution(httpStatus) + val zipCode = pickFromDistribution(userZipCode) + new PageView(page, status, zipCode, id).toString() + } + + def main(args : Array[String]) { + if (args.length != 2) { + System.err.println("Usage: PageViewGenerator ") + System.exit(1) + } + val port = args(0).toInt + val viewsPerSecond = args(1).toFloat + val sleepDelayMs = (1000.0 / viewsPerSecond).toInt + val listener = new ServerSocket(port) + println("Listening on port: " + port) + + while (true) { + val socket = listener.accept() + new Thread() { + override def run = { + println("Got client connected from: " + socket.getInetAddress) + val out = new PrintWriter(socket.getOutputStream(), true) + + while (true) { + Thread.sleep(sleepDelayMs) + out.write(getNextClickEvent()) + out.flush() + } + socket.close() + } + }.start() + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala new file mode 100644 index 0000000000..8282cc9269 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples.clickstream + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.SparkContext._ + +/** Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + */ +object PageViewStream { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: PageViewStream ") + System.err.println(" must be one of pageCounts, slidingPageCounts," + + " errorRatePerZipCode, activeUserCount, popularUsersSeen") + System.exit(1) + } + val metric = args(0) + val host = args(1) + val port = args(2).toInt + + // Create the context + val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a NetworkInputDStream on target host:port and convert each line to a PageView + val pageViews = ssc.socketTextStream(host, port) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) + + // Return a count of views per URL seen in each batch + val pageCounts = pageViews.map(view => view.url).countByValue() + + // Return a sliding window of page views per URL in the last ten seconds + val slidingPageCounts = pageViews.map(view => view.url) + .countByValueAndWindow(Seconds(10), Seconds(2)) + + + // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds + val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) + .map(view => ((view.zipCode, view.status))) + .groupByKey() + val errorRatePerZipCode = statusesPerZipCode.map{ + case(zip, statuses) => + val normalCount = statuses.filter(_ == 200).size + val errorCount = statuses.size - normalCount + val errorRatio = errorCount.toFloat / statuses.size + if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} + else {"%s: %s".format(zip, errorRatio)} + } + + // Return the number unique users in last 15 seconds + val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) + .map(view => (view.userID, 1)) + .groupByKey() + .count() + .map("Unique active users: " + _) + + // An external dataset we want to join to this stream + val userList = ssc.sparkContext.parallelize( + Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + + metric match { + case "pageCounts" => pageCounts.print() + case "slidingPageCounts" => slidingPageCounts.print() + case "errorRatePerZipCode" => errorRatePerZipCode.print() + case "activeUserCount" => activeUserCount.print() + case "popularUsersSeen" => + // Look for users in our existing dataset and print it out if we have a match + pageViews.map(view => (view.userID, 1)) + .foreach((rdd, time) => rdd.join(userList) + .map(_._2._2) + .take(10) + .foreach(u => println("Saw user %s at time %s".format(u, time)))) + case _ => println("Invalid metric entered: " + metric) + } + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/examples/BroadcastTest.scala b/examples/src/main/scala/spark/examples/BroadcastTest.scala deleted file mode 100644 index 911490cb6c..0000000000 --- a/examples/src/main/scala/spark/examples/BroadcastTest.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object BroadcastTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: BroadcastTest [] [numElem]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 - - var arr1 = new Array[Int](num) - for (i <- 0 until arr1.length) { - arr1(i) = i - } - - for (i <- 0 until 2) { - println("Iteration " + i) - println("===========") - val barr1 = sc.broadcast(arr1) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size) - } - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala deleted file mode 100644 index 104bfd5204..0000000000 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat -import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.ColumnFamilyInputFormat -import org.apache.cassandra.thrift._ -import spark.SparkContext -import spark.SparkContext._ -import java.nio.ByteBuffer -import java.util.SortedMap -import org.apache.cassandra.db.IColumn -import org.apache.cassandra.utils.ByteBufferUtil -import scala.collection.JavaConversions._ - - -/* - * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra - * support for Hadoop. - * - * To run this example, run this file with the following command params - - * - * - * So if you want to run this on localhost this will be, - * local[3] localhost 9160 - * - * The example makes some assumptions: - * 1. You have already created a keyspace called casDemo and it has a column family named Words - * 2. There are column family has a column named "para" which has test content. - * - * You can create the content by running the following script at the bottom of this file with - * cassandra-cli. - * - */ -object CassandraTest { - - def main(args: Array[String]) { - - // Get a SparkContext - val sc = new SparkContext(args(0), "casDemo") - - // Build the job configuration with ConfigHelper provided by Cassandra - val job = new Job() - job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) - - val host: String = args(1) - val port: String = args(2) - - ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) - ConfigHelper.setInputRpcPort(job.getConfiguration(), port) - ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) - ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) - ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") - ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") - - val predicate = new SlicePredicate() - val sliceRange = new SliceRange() - sliceRange.setStart(Array.empty[Byte]) - sliceRange.setFinish(Array.empty[Byte]) - predicate.setSlice_range(sliceRange) - ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) - - ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - - // Make a new Hadoop RDD - val casRdd = sc.newAPIHadoopRDD( - job.getConfiguration(), - classOf[ColumnFamilyInputFormat], - classOf[ByteBuffer], - classOf[SortedMap[ByteBuffer, IColumn]]) - - // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd.map { - case (key, value) => { - ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } - } - - // Lets get the word count in paras - val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) - - counts.collect().foreach { - case (word, count) => println(word + ":" + count) - } - - counts.map { - case (word, count) => { - val colWord = new org.apache.cassandra.thrift.Column() - colWord.setName(ByteBufferUtil.bytes("word")) - colWord.setValue(ByteBufferUtil.bytes(word)) - colWord.setTimestamp(System.currentTimeMillis) - - val colCount = new org.apache.cassandra.thrift.Column() - colCount.setName(ByteBufferUtil.bytes("wcount")) - colCount.setValue(ByteBufferUtil.bytes(count.toLong)) - colCount.setTimestamp(System.currentTimeMillis) - - val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil - mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) - mutations.get(0).column_or_supercolumn.setColumn(colWord) - mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) - mutations.get(1).column_or_supercolumn.setColumn(colCount) - (outputkey, mutations) - } - }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], - classOf[ColumnFamilyOutputFormat], job.getConfiguration) - } -} - -/* -create keyspace casDemo; -use casDemo; - -create column family WordCount with comparator = UTF8Type; -update column family WordCount with column_metadata = - [{column_name: word, validation_class: UTF8Type}, - {column_name: wcount, validation_class: LongType}]; - -create column family Words with comparator = UTF8Type; -update column family Words with column_metadata = - [{column_name: book, validation_class: UTF8Type}, - {column_name: para, validation_class: UTF8Type}]; - -assume Words keys as utf8; - -set Words['3musk001']['book'] = 'The Three Musketeers'; -set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market - town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to - be in as perfect a state of revolution as if the Huguenots had just made - a second La Rochelle of it. Many citizens, seeing the women flying - toward the High Street, leaving their children crying at the open doors, - hastened to don the cuirass, and supporting their somewhat uncertain - courage with a musket or a partisan, directed their steps toward the - hostelry of the Jolly Miller, before which was gathered, increasing - every minute, a compact group, vociferous and full of curiosity.'; - -set Words['3musk002']['book'] = 'The Three Musketeers'; -set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without - some city or other registering in its archives an event of this kind. There were - nobles, who made war against each other; there was the king, who made - war against the cardinal; there was Spain, which made war against the - king. Then, in addition to these concealed or public, secret or open - wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, - who made war upon everybody. The citizens always took up arms readily - against thieves, wolves or scoundrels, often against nobles or - Huguenots, sometimes against the king, but never against cardinal or - Spain. It resulted, then, from this habit that on the said first Monday - of April, 1625, the citizens, on hearing the clamor, and seeing neither - the red-and-yellow standard nor the livery of the Duc de Richelieu, - rushed toward the hostel of the Jolly Miller. When arrived there, the - cause of the hubbub was apparent to all'; - -set Words['3musk003']['book'] = 'The Three Musketeers'; -set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however - large the sum may be; but you ought also to endeavor to perfect yourself in - the exercises becoming a gentleman. I will write a letter today to the - Director of the Royal Academy, and tomorrow he will admit you without - any expense to yourself. Do not refuse this little service. Our - best-born and richest gentlemen sometimes solicit it without being able - to obtain it. You will learn horsemanship, swordsmanship in all its - branches, and dancing. You will make some desirable acquaintances; and - from time to time you can call upon me, just to tell me how you are - getting on, and to say whether I can be of further service to you.'; - - -set Words['thelostworld001']['book'] = 'The Lost World'; -set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined - against the red curtain. How beautiful she was! And yet how aloof! We had been - friends, quite good friends; but never could I get beyond the same - comradeship which I might have established with one of my - fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, - and perfectly unsexual. My instincts are all against a woman being too - frank and at her ease with me. It is no compliment to a man. Where - the real sex feeling begins, timidity and distrust are its companions, - heritage from old wicked days when love and violence went often hand in - hand. The bent head, the averted eye, the faltering voice, the wincing - figure--these, and not the unshrinking gaze and frank reply, are the - true signals of passion. Even in my short life I had learned as much - as that--or had inherited it in that race memory which we call instinct.'; - -set Words['thelostworld002']['book'] = 'The Lost World'; -set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, - red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was - the real boss; but he lived in the rarefied atmosphere of some Olympian - height from which he could distinguish nothing smaller than an - international crisis or a split in the Cabinet. Sometimes we saw him - passing in lonely majesty to his inner sanctum, with his eyes staring - vaguely and his mind hovering over the Balkans or the Persian Gulf. He - was above and beyond us. But McArdle was his first lieutenant, and it - was he that we knew. The old man nodded as I entered the room, and he - pushed his spectacles far up on his bald forehead.'; - -*/ diff --git a/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala deleted file mode 100644 index 67ddaec8d2..0000000000 --- a/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object ExceptionHandlingTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: ExceptionHandlingTest ") - System.exit(1) - } - - val sc = new SparkContext(args(0), "ExceptionHandlingTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - sc.parallelize(0 until sc.defaultParallelism).foreach { i => - if (math.random > 0.75) - throw new Exception("Testing exception handling") - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/GroupByTest.scala b/examples/src/main/scala/spark/examples/GroupByTest.scala deleted file mode 100644 index 5cee413615..0000000000 --- a/examples/src/main/scala/spark/examples/GroupByTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object GroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) - } - arr1 - }.cache - // Enforce that everything has been calculated and in cache - pairs1.count - - println(pairs1.groupByKey(numReducers).count) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala deleted file mode 100644 index 4dd6c243ac..0000000000 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ -import spark.rdd.NewHadoopRDD -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} -import org.apache.hadoop.hbase.client.HBaseAdmin -import org.apache.hadoop.hbase.mapreduce.TableInputFormat - -object HBaseTest { - def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HBaseTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val conf = HBaseConfiguration.create() - - // Other options for configuring scan behavior are available. More information available at - // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html - conf.set(TableInputFormat.INPUT_TABLE, args(1)) - - // Initialize hBase table if necessary - val admin = new HBaseAdmin(conf) - if(!admin.isTableAvailable(args(1))) { - val tableDesc = new HTableDescriptor(args(1)) - admin.createTable(tableDesc) - } - - val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], - classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], - classOf[org.apache.hadoop.hbase.client.Result]) - - hBaseRDD.count() - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/HdfsTest.scala b/examples/src/main/scala/spark/examples/HdfsTest.scala deleted file mode 100644 index 23258336e2..0000000000 --- a/examples/src/main/scala/spark/examples/HdfsTest.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ - -object HdfsTest { - def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HdfsTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val file = sc.textFile(args(1)) - val mapped = file.map(s => s.length).cache() - for (iter <- 1 to 10) { - val start = System.currentTimeMillis() - for (x <- mapped) { x + 2 } - // println("Processing: " + x) - val end = System.currentTimeMillis() - println("Iteration " + iter + " took " + (end-start) + " ms") - } - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalALS.scala b/examples/src/main/scala/spark/examples/LocalALS.scala deleted file mode 100644 index 7a449a9d72..0000000000 --- a/examples/src/main/scala/spark/examples/LocalALS.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.sqrt -import cern.jet.math._ -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ - -/** - * Alternating least squares matrix factorization. - */ -object LocalALS { - // Parameters set through command line arguments - var M = 0 // Number of movies - var U = 0 // Number of users - var F = 0 // Number of features - var ITERATIONS = 0 - - val LAMBDA = 0.01 // Regularization coefficient - - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) - } - - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) - for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) - } - //println("R: " + r) - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) - } - - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each user that rated the movie - for (j <- 0 until U) { - val u = us(j) - // Add u * u^t to XtX - blas.dger(1, u, u, XtX) - // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def main(args: Array[String]) { - args match { - case Array(m, u, f, iters) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - } - case _ => { - System.err.println("Usage: LocalALS ") - System.exit(1) - } - } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); - - val R = generateR() - - // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) - - // Iteratively update movies then users - for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") - ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray - us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray - println("RMSE = " + rmse(R, ms, us)) - println() - } - } -} diff --git a/examples/src/main/scala/spark/examples/LocalFileLR.scala b/examples/src/main/scala/spark/examples/LocalFileLR.scala deleted file mode 100644 index c1f8d32aa8..0000000000 --- a/examples/src/main/scala/spark/examples/LocalFileLR.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector - -object LocalFileLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def parsePoint(line: String): DataPoint = { - val nums = line.split(' ').map(_.toDouble) - return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) - } - - def main(args: Array[String]) { - val lines = scala.io.Source.fromFile(args(0)).getLines().toArray - val points = lines.map(parsePoint _) - val ITERATIONS = args(1).toInt - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - var gradient = Vector.zeros(D) - for (p <- points) { - val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y - gradient += scale * p.x - } - w -= gradient - } - - println("Final w: " + w) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalKMeans.scala b/examples/src/main/scala/spark/examples/LocalKMeans.scala deleted file mode 100644 index 0a0bc6f476..0000000000 --- a/examples/src/main/scala/spark/examples/LocalKMeans.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector -import spark.SparkContext._ -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -/** - * K-means clustering. - */ -object LocalKMeans { - val N = 1000 - val R = 1000 // Scaling factor - val D = 10 - val K = 10 - val convergeDist = 0.001 - val rand = new Random(42) - - def generateData = { - def generatePoint(i: Int) = { - Vector(D, _ => rand.nextDouble * R) - } - Array.tabulate(N)(generatePoint) - } - - def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { - var index = 0 - var bestIndex = 0 - var closest = Double.PositiveInfinity - - for (i <- 1 to centers.size) { - val vCurr = centers.get(i).get - val tempDist = p.squaredDist(vCurr) - if (tempDist < closest) { - closest = tempDist - bestIndex = i - } - } - - return bestIndex - } - - def main(args: Array[String]) { - val data = generateData - var points = new HashSet[Vector] - var kPoints = new HashMap[Int, Vector] - var tempDist = 1.0 - - while (points.size < K) { - points.add(data(rand.nextInt(N))) - } - - val iter = points.iterator - for (i <- 1 to points.size) { - kPoints.put(i, iter.next()) - } - - println("Initial centers: " + kPoints) - - while(tempDist > convergeDist) { - var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - - var mappings = closest.groupBy[Int] (x => x._1) - - var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))}) - - var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} - - tempDist = 0.0 - for (mapping <- newPoints) { - tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) - } - - for (newP <- newPoints) { - kPoints.put(newP._1, newP._2) - } - } - - println("Final centers: " + kPoints) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala deleted file mode 100644 index ab99bf1fbe..0000000000 --- a/examples/src/main/scala/spark/examples/LocalLR.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector - -/** - * Logistic regression based classification. - */ -object LocalLR { - val N = 10000 // Number of data points - val D = 10 // Number of dimensions - val R = 0.7 // Scaling factor - val ITERATIONS = 5 - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 - val x = Vector(D, _ => rand.nextGaussian + y * R) - DataPoint(x, y) - } - Array.tabulate(N)(generatePoint) - } - - def main(args: Array[String]) { - val data = generateData - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - var gradient = Vector.zeros(D) - for (p <- data) { - val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y - gradient += scale * p.x - } - w -= gradient - } - - println("Final w: " + w) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalPi.scala b/examples/src/main/scala/spark/examples/LocalPi.scala deleted file mode 100644 index ccd69695df..0000000000 --- a/examples/src/main/scala/spark/examples/LocalPi.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.random -import spark._ -import SparkContext._ - -object LocalPi { - def main(args: Array[String]) { - var count = 0 - for (i <- 1 to 100000) { - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 - } - println("Pi is roughly " + 4 * count / 100000.0) - } -} diff --git a/examples/src/main/scala/spark/examples/LogQuery.scala b/examples/src/main/scala/spark/examples/LogQuery.scala deleted file mode 100644 index e815ececf7..0000000000 --- a/examples/src/main/scala/spark/examples/LogQuery.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -/** - * Executes a roll up-style query against Apache logs. - */ -object LogQuery { - val exampleApacheLogs = List( - """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg - | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; - | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR - | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR - | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" - | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""), - """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg - | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; - | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR - | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR - | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" - | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "") - ) - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: LogQuery [logFile]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Log Query", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val dataSet = - if (args.length == 2) sc.textFile(args(1)) - else sc.parallelize(exampleApacheLogs) - - val apacheLogRegex = - """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r - - /** Tracks the total query count and number of aggregate bytes for a particular group. */ - class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) - } - - def extractKey(line: String): (String, String, String) = { - apacheLogRegex.findFirstIn(line) match { - case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => - if (user != "\"-\"") (ip, user, query) - else (null, null, null) - case _ => (null, null, null) - } - } - - def extractStats(line: String): Stats = { - apacheLogRegex.findFirstIn(line) match { - case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => - new Stats(1, bytes.toInt) - case _ => new Stats(1, 0) - } - } - - dataSet.map(line => (extractKey(line), extractStats(line))) - .reduceByKey((a, b) => a.merge(b)) - .collect().foreach{ - case (user, query) => println("%s\t%s".format(user, query))} - } -} diff --git a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala deleted file mode 100644 index d0b1cf06e5..0000000000 --- a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object MultiBroadcastTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: BroadcastTest [] [numElem]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 - - var arr1 = new Array[Int](num) - for (i <- 0 until arr1.length) { - arr1(i) = i - } - - var arr2 = new Array[Int](num) - for (i <- 0 until arr2.length) { - arr2(i) = i - } - - val barr1 = sc.broadcast(arr1) - val barr2 = sc.broadcast(arr2) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size + barr2.value.size) - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala deleted file mode 100644 index d197bbaf7c..0000000000 --- a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object SimpleSkewedGroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SimpleSkewedGroupByTest " + - "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - var ratio = if (args.length > 5) args(5).toInt else 5.0 - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - var result = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - val offset = ranGen.nextInt(1000) * numReducers - if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) { - // give ratio times higher chance of generating key 0 (for reducer 0) - result(i) = (offset, byteArr) - } else { - // generate a key for one of the other reducers - val key = 1 + ranGen.nextInt(numReducers-1) + offset - result(i) = (key, byteArr) - } - } - result - }.cache - // Enforce that everything has been calculated and in cache - pairs1.count - - println("RESULT: " + pairs1.groupByKey(numReducers).count) - // Print how many keys each reducer got (for debugging) - //println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala deleted file mode 100644 index 4641b82444..0000000000 --- a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object SkewedGroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - - // map output sizes lineraly increase from the 1st to the last - numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt - - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) - } - arr1 - }.cache() - // Enforce that everything has been calculated and in cache - pairs1.count() - - println(pairs1.groupByKey(numReducers).count()) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala deleted file mode 100644 index ba0dfd8f9b..0000000000 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.sqrt -import cern.jet.math._ -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import spark._ - -/** - * Alternating least squares matrix factorization. - */ -object SparkALS { - // Parameters set through command line arguments - var M = 0 // Number of movies - var U = 0 // Number of users - var F = 0 // Number of features - var ITERATIONS = 0 - - val LAMBDA = 0.01 // Regularization coefficient - - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) - } - - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) - for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) - } - //println("R: " + r) - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) - } - - def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val U = us.size - val F = us(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each user that rated the movie - for (j <- 0 until U) { - val u = us(j) - // Add u * u^t to XtX - blas.dger(1, u, u, XtX) - // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - - var host = "" - var slices = 0 - - val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) - - options.toArray match { - case Array(host_, m, u, f, iters, slices_) => - host = host_.get - M = m.getOrElse("100").toInt - U = u.getOrElse("500").toInt - F = f.getOrElse("10").toInt - ITERATIONS = iters.getOrElse("5").toInt - slices = slices_.getOrElse("2").toInt - case _ => - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) - - val sc = new SparkContext(host, "SparkALS", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val R = generateR() - - // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) - - // Iteratively update movies then users - val Rc = sc.broadcast(R) - var msb = sc.broadcast(ms) - var usb = sc.broadcast(us) - for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") - ms = sc.parallelize(0 until M, slices) - .map(i => update(i, msb.value(i), usb.value, Rc.value)) - .toArray - msb = sc.broadcast(ms) // Re-broadcast ms because it was updated - us = sc.parallelize(0 until U, slices) - .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) - .toArray - usb = sc.broadcast(us) // Re-broadcast us because it was updated - println("RMSE = " + rmse(R, ms, us)) - println() - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala deleted file mode 100644 index 43c9115664..0000000000 --- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import scala.math.exp -import spark.util.Vector -import spark._ -import spark.scheduler.InputFormatInfo - -/** - * Logistic regression based classification. - */ -object SparkHdfsLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def parsePoint(line: String): DataPoint = { - //val nums = line.split(' ').map(_.toDouble) - //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) - val tok = new java.util.StringTokenizer(line, " ") - var y = tok.nextToken.toDouble - var x = new Array[Double](D) - var i = 0 - while (i < D) { - x(i) = tok.nextToken.toDouble; i += 1 - } - return DataPoint(new Vector(x), y) - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkHdfsLR ") - System.exit(1) - } - val inputPath = args(1) - val conf = SparkEnv.get.hadoop.newConfiguration() - val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) - val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).cache() - val ITERATIONS = args(2).toInt - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala deleted file mode 100644 index 38ed3b149a..0000000000 --- a/examples/src/main/scala/spark/examples/SparkKMeans.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.SparkContext -import spark.util.Vector -import spark.SparkContext._ -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -/** - * K-means clustering. - */ -object SparkKMeans { - val R = 1000 // Scaling factor - val rand = new Random(42) - - def parseVector(line: String): Vector = { - return new Vector(line.split(' ').map(_.toDouble)) - } - - def closestPoint(p: Vector, centers: Array[Vector]): Int = { - var index = 0 - var bestIndex = 0 - var closest = Double.PositiveInfinity - - for (i <- 0 until centers.length) { - val tempDist = p.squaredDist(centers(i)) - if (tempDist < closest) { - closest = tempDist - bestIndex = i - } - } - - return bestIndex - } - - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: SparkLocalKMeans ") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLocalKMeans", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = sc.textFile(args(1)) - val data = lines.map(parseVector _).cache() - val K = args(2).toInt - val convergeDist = args(3).toDouble - - var kPoints = data.takeSample(false, K, 42).toArray - var tempDist = 1.0 - - while(tempDist > convergeDist) { - var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - - var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} - - var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() - - tempDist = 0.0 - for (i <- 0 until K) { - tempDist += kPoints(i).squaredDist(newPoints(i)) - } - - for (newP <- newPoints) { - kPoints(newP._1) = newP._2 - } - println("Finished iteration (delta = " + tempDist + ")") - } - - println("Final centers:") - kPoints.foreach(println) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala deleted file mode 100644 index 52a0d69744..0000000000 --- a/examples/src/main/scala/spark/examples/SparkLR.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import scala.math.exp -import spark.util.Vector -import spark._ - -/** - * Logistic regression based classification. - */ -object SparkLR { - val N = 10000 // Number of data points - val D = 10 // Numer of dimensions - val R = 0.7 // Scaling factor - val ITERATIONS = 5 - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 - val x = Vector(D, _ => rand.nextGaussian + y * R) - DataPoint(x, y) - } - Array.tabulate(N)(generatePoint) - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkLR []") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val numSlices = if (args.length > 1) args(1).toInt else 2 - val points = sc.parallelize(generateData, numSlices).cache() - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkPageRank.scala b/examples/src/main/scala/spark/examples/SparkPageRank.scala deleted file mode 100644 index dedbbd01a3..0000000000 --- a/examples/src/main/scala/spark/examples/SparkPageRank.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.examples - -import spark.SparkContext._ -import spark.SparkContext - - -/** - * Computes the PageRank of URLs from an input file. Input file should - * be in format of: - * URL neighbor URL - * URL neighbor URL - * URL neighbor URL - * ... - * where URL and their neighbors are separated by space(s). - */ -object SparkPageRank { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: PageRank ") - System.exit(1) - } - var iters = args(2).toInt - val ctx = new SparkContext(args(0), "PageRank", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = ctx.textFile(args(1), 1) - val links = lines.map{ s => - val parts = s.split("\\s+") - (parts(0), parts(1)) - }.distinct().groupByKey().cache() - var ranks = links.mapValues(v => 1.0) - - for (i <- 1 to iters) { - val contribs = links.join(ranks).values.flatMap{ case (urls, rank) => - val size = urls.size - urls.map(url => (url, rank / size)) - } - ranks = contribs.reduceByKey(_ + _).mapValues(0.15 + 0.85 * _) - } - - val output = ranks.collect() - output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SparkPi.scala b/examples/src/main/scala/spark/examples/SparkPi.scala deleted file mode 100644 index 00560ac9d1..0000000000 --- a/examples/src/main/scala/spark/examples/SparkPi.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.random -import spark._ -import SparkContext._ - -/** Computes an approximation to pi */ -object SparkPi { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkPi", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - val n = 100000 * slices - val count = spark.parallelize(1 to n, slices).map { i => - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 - }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkTC.scala b/examples/src/main/scala/spark/examples/SparkTC.scala deleted file mode 100644 index bf988a953b..0000000000 --- a/examples/src/main/scala/spark/examples/SparkTC.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ -import SparkContext._ -import scala.util.Random -import scala.collection.mutable - -/** - * Transitive closure on a graph. - */ -object SparkTC { - val numEdges = 200 - val numVertices = 100 - val rand = new Random(42) - - def generateGraph = { - val edges: mutable.Set[(Int, Int)] = mutable.Set.empty - while (edges.size < numEdges) { - val from = rand.nextInt(numVertices) - val to = rand.nextInt(numVertices) - if (from != to) edges.+=((from, to)) - } - edges.toSeq - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTC []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTC", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - var tc = spark.parallelize(generateGraph, slices).cache() - - // Linear transitive closure: each round grows paths by one edge, - // by joining the graph's edges with the already-discovered paths. - // e.g. join the path (y, z) from the TC with the edge (x, y) from - // the graph to obtain the path (x, z). - - // Because join() joins on keys, the edges are stored in reversed order. - val edges = tc.map(x => (x._2, x._1)) - - // This join is iterated until a fixed point is reached. - var oldCount = 0L - var nextCount = tc.count() - do { - oldCount = nextCount - // Perform the join, obtaining an RDD of (y, (z, x)) pairs, - // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); - nextCount = tc.count() - } while (nextCount != oldCount) - - println("TC has " + tc.count() + " edges.") - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index c23ee9895f..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples.bagel - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} - -import com.esotericsoftware.kryo._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) - self.outEdges.map(targetId => - new PRMessage(targetId, newValue / self.outEdges.size)) - else - Array[PRMessage]() - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 00635a7ffa..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples.bagel - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRank ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRank") - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache - else - vertices = vertices.cache - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect.mkString) - println(top) - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index c416ddbc58..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples.bagel - -import spark._ -import serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRankStandalone ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRankStandalone") - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) - input.map(parseArticle _).partitionBy(partitioner).cache() - else - input.map(parseArticle _).cache() - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - System.exit(0) - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapper, rankWrapper)) => - if (linksWrapper.length > 0) { - if (rankWrapper.length > 0) { - linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) - } else { - linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala deleted file mode 100644 index 05d3176478..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import scala.collection.mutable.LinkedList -import scala.util.Random - -import akka.actor.Actor -import akka.actor.ActorRef -import akka.actor.Props -import akka.actor.actorRef2Scala - -import spark.streaming.Seconds -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext.toPairDStreamFunctions -import spark.streaming.receivers.Receiver -import spark.util.AkkaUtils - -case class SubscribeReceiver(receiverActor: ActorRef) -case class UnsubscribeReceiver(receiverActor: ActorRef) - -/** - * Sends the random content to every receiver subscribed with 1/2 - * second delay. - */ -class FeederActor extends Actor { - - val rand = new Random() - var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() - - val strings: Array[String] = Array("words ", "may ", "count ") - - def makeMessage(): String = { - val x = rand.nextInt(3) - strings(x) + strings(2 - x) - } - - /* - * A thread to generate random messages - */ - new Thread() { - override def run() { - while (true) { - Thread.sleep(500) - receivers.foreach(_ ! makeMessage) - } - } - }.start() - - def receive: Receive = { - - case SubscribeReceiver(receiverActor: ActorRef) => - println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers - - case UnsubscribeReceiver(receiverActor: ActorRef) => - println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) - - } -} - -/** - * A sample actor as receiver, is also simplest. This receiver actor - * goes and subscribe to a typical publisher/feeder actor and receives - * data. - * - * @see [[spark.streaming.examples.FeederActor]] - */ -class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) -extends Actor with Receiver { - - lazy private val remotePublisher = context.actorFor(urlOfPublisher) - - override def preStart = remotePublisher ! SubscribeReceiver(context.self) - - def receive = { - case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T]) - } - - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) - -} - -/** - * A sample feeder actor - * - * Usage: FeederActor - * and describe the AkkaSystem that Spark Sample feeder would start on. - */ -object FeederActor { - - def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) - System.exit(1) - } - val Seq(host, port) = args.toSeq - - - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1 - val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") - - println("Feeder started as:" + feeder) - - actorSystem.awaitTermination(); - } -} - -/** - * A sample word count program demonstrating the use of plugging in - * Actor as Receiver - * Usage: ActorWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the AkkaSystem that Spark Sample feeder is running on. - * - * To run this example locally, you may run Feeder Actor as - * `$ ./run-example spark.streaming.examples.FeederActor 127.0.1.1 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999` - */ -object ActorWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ActorWordCount " + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - val Seq(master, host, port) = args.toSeq - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - /* - * Following is the use of actorStream to plug in custom actor as receiver - * - * An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDstream - * should be same. - * - * For example: Both actorStream and SampleActorReceiver are parameterized - * to same type to ensure type safety. - */ - - val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( - host, port.toInt))), "SampleReceiver") - - //compute wordcount - lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala deleted file mode 100644 index 3ab4fc2c37..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel -import spark.streaming._ - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: FlumeEventCount - * - * is a Spark master URL - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - */ -object FlumeEventCount { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println( - "Usage: FlumeEventCount ") - System.exit(1) - } - - val Array(master, host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a flume stream - val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala deleted file mode 100644 index 30af01a26f..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - - -/** - * Counts words in new text files created in the given directory - * Usage: HdfsWordCount - * is the Spark master URL. - * is the directory that Spark Streaming will use to find and read new text files. - * - * To run this on your local machine on directory `localdir`, run this example - * `$ ./run-example spark.streaming.examples.HdfsWordCount local[2] localdir` - * Then create a text file in `localdir` and the words in the file will get counted. - */ -object HdfsWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: HdfsWordCount ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala deleted file mode 100644 index d9c76d1a33..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import java.util.Properties -import kafka.message.Message -import kafka.producer.SyncProducerConfig -import kafka.producer._ -import spark.SparkContext -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.storage.StorageLevel -import spark.streaming.util.RawTextHelper._ - -/** - * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: KafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * is a list of one or more zookeeper servers that make quorum - * is the name of kafka consumer group - * is a list of one or more kafka topics to consume from - * is the number of threads the kafka consumer should use - * - * Example: - * `./run-example spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` - */ -object KafkaWordCount { - def main(args: Array[String]) { - - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount ") - System.exit(1) - } - - val Array(master, zkQuorum, group, topics, numThreads) = args - - val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - ssc.checkpoint("checkpoint") - - val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) - wordCounts.print() - - ssc.start() - } -} - -// Produces some random words between 1 and 100. -object KafkaWordCountProducer { - - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: KafkaWordCountProducer ") - System.exit(1) - } - - val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", zkQuorum) - props.put("serializer.class", "kafka.serializer.StringEncoder") - - val config = new ProducerConfig(props) - val producer = new Producer[String, String](config) - - // Send some messages - while(true) { - val messages = (1 to messagesPerSec.toInt).map { messageNum => - (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") - }.toArray - println(messages.mkString(",")) - val data = new ProducerData[String, String](topic, messages) - producer.send(data) - Thread.sleep(100) - } - } - -} - diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala deleted file mode 100644 index b29d79aac5..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -/** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.NetworkWordCount local[2] localhost 9999` - */ -object NetworkWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala deleted file mode 100644 index da36c8c23c..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.RDD -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -import scala.collection.mutable.SynchronizedQueue - -object QueueStream { - - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: QueueStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() - - // Create the QueueInputDStream and use it do some processing - val inputStream = ssc.queueStream(rddQueue) - val mappedStream = inputStream.map(x => (x % 10, 1)) - val reducedStream = mappedStream.reduceByKey(_ + _) - reducedStream.print() - ssc.start() - - // Create and push some RDDs into - for (i <- 1 to 30) { - rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) - Thread.sleep(1000) - } - ssc.stop() - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala deleted file mode 100644 index 7fb680bcc3..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.util.RawTextHelper - -/** - * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited - * lines have the word 'the' in them. This is useful for benchmarking purposes. This - * will only work with spark.streaming.util.RawTextSender running on all worker nodes - * and with Spark using Kryo serialization (set Java property "spark.serializer" to - * "spark.KryoSerializer"). - * Usage: RawNetworkGrep - * is the Spark master URL - * is the number rawNetworkStreams, which should be same as number - * of work nodes in the cluster - * is "localhost". - * is the port on which RawTextSender is running in the worker nodes. - * is the Spark Streaming batch duration in milliseconds. - */ - -object RawNetworkGrep { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: RawNetworkGrep ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - RawTextHelper.warmUp(ssc.sparkContext) - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("the")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala deleted file mode 100644 index b709fc3c87..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ - -/** - * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: StatefulNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` - */ -object StatefulNetworkWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: StatefulNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.foldLeft(0)(_ + _) - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - ssc.checkpoint(".") - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordDstream = words.map(x => (x, 1)) - - // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) - stateDstream.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala deleted file mode 100644 index 8770abd57e..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.storage.StorageLevel -import com.twitter.algebird._ -import spark.streaming.StreamingContext._ -import spark.SparkContext._ - -/** - * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute - * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. - *
    - * Note that since Algebird's implementation currently only supports Long inputs, - * the example operates on Long IDs. Once the implementation supports other inputs (such as String), - * the same approach could be used for computing popular topics for example. - *

    - *

    - * - * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a datastructure - * for approximate frequency estimation in data streams (e.g. Top-K elements, frequency of any given element, etc), - * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the - * estimated count of an element can be computed, as well as "heavy-hitters" that occur more than a threshold - * percentage of the overall total count. - *

    - * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation. - */ -object TwitterAlgebirdCMS { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdCMS " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - // CMS parameters - val DELTA = 1E-3 - val EPS = 0.01 - val SEED = 1 - val PERC = 0.001 - // K highest frequency elements to take - val TOPK = 10 - - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) - var globalCMS = cms.zero - val mm = new MapMonoid[Long, Int]() - var globalExact = Map[Long, Int]() - - val approxTopUsers = users.mapPartitions(ids => { - ids.map(id => cms.create(id)) - }).reduce(_ ++ _) - - val exactTopUsers = users.map(id => (id, 1)) - .reduceByKey((a, b) => a + b) - - approxTopUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - val partialTopK = partial.heavyHitters.map(id => - (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - globalCMS ++= partial - val globalTopK = globalCMS.heavyHitters.map(id => - (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, - partialTopK.mkString("[", ",", "]"))) - println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, - globalTopK.mkString("[", ",", "]"))) - } - }) - - exactTopUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partialMap = rdd.collect().toMap - val partialTopK = rdd.map( - {case (id, count) => (count, id)}) - .sortByKey(ascending = false).take(TOPK) - globalExact = mm.plus(globalExact.toMap, partialMap) - val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) - println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) - } - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala deleted file mode 100644 index cba5c986be..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.storage.StorageLevel -import com.twitter.algebird.HyperLogLog._ -import com.twitter.algebird.HyperLogLogMonoid -import spark.streaming.dstream.TwitterInputDStream - -/** - * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute - * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. - *

    - *

    - * This - * blog post and this - * blog post - * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for estimating - * the cardinality of a data stream, i.e. the number of unique elements. - *

    - * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation. - */ -object TwitterAlgebirdHLL { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdHLL " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ - val BIT_SIZE = 12 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val hll = new HyperLogLogMonoid(BIT_SIZE) - var globalHll = hll.zero - var userSet: Set[Long] = Set() - - val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll(id)) - }).reduce(_ + _) - - val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) - - approxUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - globalHll += partial - println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) - println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) - } - }) - - exactUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - userSet ++= partial - println("Exact distinct users this batch: %d".format(partial.size)) - println("Exact distinct users overall: %d".format(userSet.size)) - println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100)) - } - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala deleted file mode 100644 index 682b99f75e..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import StreamingContext._ -import spark.SparkContext._ - -/** - * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter - * stream. The stream is instantiated with credentials and optionally filters supplied by the - * command line arguments. - * - */ -object TwitterPopularTags { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterPopularTags " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters) - - val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) - - val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - - // Print popular hashtags - topCounts60.foreach(rdd => { - val topList = rdd.take(5) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - topCounts10.foreach(rdd => { - val topList = rdd.take(5) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala deleted file mode 100644 index a0cae06c30..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import akka.actor.ActorSystem -import akka.actor.actorRef2Scala -import akka.zeromq._ -import spark.streaming.{ Seconds, StreamingContext } -import spark.streaming.StreamingContext._ -import akka.zeromq.Subscribe - -/** - * A simple publisher for demonstration purposes, repeatedly publishes random Messages - * every one second. - */ -object SimpleZeroMQPublisher { - - def main(args: Array[String]) = { - if (args.length < 2) { - System.err.println("Usage: SimpleZeroMQPublisher ") - System.exit(1) - } - - val Seq(url, topic) = args.toSeq - val acs: ActorSystem = ActorSystem() - - val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - val messages: Array[String] = Array("words ", "may ", "count ") - while (true) { - Thread.sleep(1000) - pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) - } - acs.awaitTermination() - } -} - -/** - * A sample wordcount with ZeroMQStream stream - * - * To work with zeroMQ, some native libraries have to be installed. - * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software) - * - * Usage: ZeroMQWordCount - * In local mode, should be 'local[n]' with n > 1 - * and describe where zeroMq publisher is running. - * - * To run this example locally, you may run publisher as - * `$ ./run-example spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` - * and run the example as - * `$ ./run-example spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` - */ -object ZeroMQWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ZeroMQWordCount " + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - val Seq(master, url, topic) = args.toSeq - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator - - //For this stream, a zeroMQ publisher should be running. - val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } - -} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala deleted file mode 100644 index dd36bbbf32..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples.clickstream - -import java.net.{InetAddress,ServerSocket,Socket,SocketException} -import java.io.{InputStreamReader, BufferedReader, PrintWriter} -import util.Random - -/** Represents a page view on a website with associated dimension data.*/ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { - override def toString() : String = { - "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) - } -} -object PageView { - def fromString(in : String) : PageView = { - val parts = in.split("\t") - new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) - } -} - -/** Generates streaming events to simulate page views on a website. - * - * This should be used in tandem with PageViewStream.scala. Example: - * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - * */ -object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, - "http://foo.com/contact" -> .1) - val httpStatus = Map(200 -> .95, - 404 -> .05) - val userZipCode = Map(94709 -> .5, - 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - - - def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { - val rand = new Random().nextDouble() - var total = 0.0 - for ((item, prob) <- inputMap) { - total = total + prob - if (total > rand) { - return item - } - } - return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 - } - - def getNextClickEvent() : String = { - val id = pickFromDistribution(userID) - val page = pickFromDistribution(pages) - val status = pickFromDistribution(httpStatus) - val zipCode = pickFromDistribution(userZipCode) - new PageView(page, status, zipCode, id).toString() - } - - def main(args : Array[String]) { - if (args.length != 2) { - System.err.println("Usage: PageViewGenerator ") - System.exit(1) - } - val port = args(0).toInt - val viewsPerSecond = args(1).toFloat - val sleepDelayMs = (1000.0 / viewsPerSecond).toInt - val listener = new ServerSocket(port) - println("Listening on port: " + port) - - while (true) { - val socket = listener.accept() - new Thread() { - override def run = { - println("Got client connected from: " + socket.getInetAddress) - val out = new PrintWriter(socket.getOutputStream(), true) - - while (true) { - Thread.sleep(sleepDelayMs) - out.write(getNextClickEvent()) - out.flush() - } - socket.close() - } - }.start() - } - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala deleted file mode 100644 index 152da23489..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples.clickstream - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ -import spark.SparkContext._ - -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. Example: - * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - */ -object PageViewStream { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println("Usage: PageViewStream ") - System.err.println(" must be one of pageCounts, slidingPageCounts," + - " errorRatePerZipCode, activeUserCount, popularUsersSeen") - System.exit(1) - } - val metric = args(0) - val host = args(1) - val port = args(2).toInt - - // Create the context - val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a NetworkInputDStream on target host:port and convert each line to a PageView - val pageViews = ssc.socketTextStream(host, port) - .flatMap(_.split("\n")) - .map(PageView.fromString(_)) - - // Return a count of views per URL seen in each batch - val pageCounts = pageViews.map(view => view.url).countByValue() - - // Return a sliding window of page views per URL in the last ten seconds - val slidingPageCounts = pageViews.map(view => view.url) - .countByValueAndWindow(Seconds(10), Seconds(2)) - - - // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds - val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) - .map(view => ((view.zipCode, view.status))) - .groupByKey() - val errorRatePerZipCode = statusesPerZipCode.map{ - case(zip, statuses) => - val normalCount = statuses.filter(_ == 200).size - val errorCount = statuses.size - normalCount - val errorRatio = errorCount.toFloat / statuses.size - if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} - else {"%s: %s".format(zip, errorRatio)} - } - - // Return the number unique users in last 15 seconds - val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) - .map(view => (view.userID, 1)) - .groupByKey() - .count() - .map("Unique active users: " + _) - - // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) - - metric match { - case "pageCounts" => pageCounts.print() - case "slidingPageCounts" => slidingPageCounts.print() - case "errorRatePerZipCode" => errorRatePerZipCode.print() - case "activeUserCount" => activeUserCount.print() - case "popularUsersSeen" => - // Look for users in our existing dataset and print it out if we have a match - pageViews.map(view => (view.userID, 1)) - .foreach((rdd, time) => rdd.join(userList) - .map(_._2._2) - .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) - } - - ssc.start() - } -} diff --git a/mllib/pom.xml b/mllib/pom.xml index ab31d5734e..2d5d3c00d1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-mllib jar Spark Project ML Library @@ -33,7 +33,7 @@ - org.spark-project + org.apache.spark spark-core ${project.version} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala new file mode 100644 index 0000000000..4f4a7f5296 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -0,0 +1,21 @@ +package org.apache.spark.mllib.classification + +import org.apache.spark.RDD + +trait ClassificationModel extends Serializable { + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Int] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Int prediction from the trained model + */ + def predict(testData: Array[Double]): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala new file mode 100644 index 0000000000..91bb50c829 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.math.round + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.DataValidators + +import org.jblas.DoubleMatrix + +/** + * Classification model trained using Logistic Regression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LogisticRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with ClassificationModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept + round(1.0/ (1.0 + math.exp(margin * -1))) + } +} + +/** + * Train a classification model for Logistic Regression using Stochastic Gradient Descent. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] + with Serializable { + + val gradient = new LogisticGradient() + val updater = new SimpleUpdater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + override val validators = List(DataValidators.classificationLabels) + + /** + * Construct a LogisticRegression object with default parameters + */ + def this() = this(1.0, 100, 0.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } +} + +/** + * Top-level methods for calling Logistic Regression. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +object LogisticRegressionWithSGD { + // NOTE(shivaram): We use multiple train methods instead of default arguments to support + // Java programs. + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LogisticRegressionModel = + { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( + input, initialWeights) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double) + : LogisticRegressionModel = + { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( + input) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. We use the entire data + * set to update the gradient in each iteration. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + + * @param numIterations Number of iterations of gradient descent to run. + * @return a LogisticRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double) + : LogisticRegressionModel = + { + train(input, numIterations, stepSize, 1.0) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using a step size of 1.0. We use the entire data set + * to update the gradient in each iteration. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LogisticRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LogisticRegressionModel = + { + train(input, numIterations, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 4) { + println("Usage: LogisticRegression " + + "") + System.exit(1) + } + val sc = new SparkContext(args(0), "LogisticRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala new file mode 100644 index 0000000000..c92c7cc3f3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.math.signum + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.DataValidators + +import org.jblas.DoubleMatrix + +/** + * Model built using SVM. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class SVMModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with ClassificationModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + val margin = dataMatrix.dot(weightMatrix) + intercept + if (margin < 0) 0.0 else 1.0 + } +} + +/** + * Train an SVM using Stochastic Gradient Descent. + * NOTE: Labels used in SVM should be {0, 1} + */ +class SVMWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { + + val gradient = new HingeGradient() + val updater = new SquaredL2Updater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + override val validators = List(DataValidators.classificationLabels) + + /** + * Construct a SVM object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new SVMModel(weights, intercept) + } +} + +/** + * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1} + */ +object SVMWithSGD { + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : SVMModel = + { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : SVMModel = + { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a SVMModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : SVMModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a SVMModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : SVMModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: SVM ") + System.exit(1) + } + val sc = new SparkContext(args(0), "SVM") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala new file mode 100644 index 0000000000..2c3db099fa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.apache.spark.{SparkContext, RDD} +import org.apache.spark.SparkContext._ +import org.apache.spark.Logging +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + + +/** + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + * they are executed together with joint passes over the data for efficiency. + * + * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given + * to it should be cached by the user. + */ +class KMeans private ( + var k: Int, + var maxIterations: Int, + var runs: Int, + var initializationMode: String, + var initializationSteps: Int, + var epsilon: Double) + extends Serializable with Logging +{ + private type ClusterCenters = Array[Array[Double]] + + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + + /** Set the number of clusters to create (k). Default: 2. */ + def setK(k: Int): KMeans = { + this.k = k + this + } + + /** Set maximum number of iterations to run. Default: 20. */ + def setMaxIterations(maxIterations: Int): KMeans = { + this.maxIterations = maxIterations + this + } + + /** + * Set the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + */ + def setInitializationMode(initializationMode: String): KMeans = { + if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { + throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) + } + this.initializationMode = initializationMode + this + } + + /** + * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm + * this many times with random starting conditions (configured by the initialization mode), then + * return the best clustering found over any run. Default: 1. + */ + def setRuns(runs: Int): KMeans = { + if (runs <= 0) { + throw new IllegalArgumentException("Number of runs must be positive") + } + this.runs = runs + this + } + + /** + * Set the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 5 is almost always enough. Default: 5. + */ + def setInitializationSteps(initializationSteps: Int): KMeans = { + if (initializationSteps <= 0) { + throw new IllegalArgumentException("Number of initialization steps must be positive") + } + this.initializationSteps = initializationSteps + this + } + + /** + * Set the distance threshold within which we've consider centers to have converged. + * If all centers move less than this Euclidean distance, we stop iterating one run. + */ + def setEpsilon(epsilon: Double): KMeans = { + this.epsilon = epsilon + this + } + + /** + * Train a K-means model on the given set of points; `data` should be cached for high + * performance, because this is an iterative algorithm. + */ + def run(data: RDD[Array[Double]]): KMeansModel = { + // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable + + val sc = data.sparkContext + + val centers = if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + + val active = Array.fill(runs)(true) + val costs = Array.fill(runs)(0.0) + + var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var iteration = 0 + + // Execute iterations of Lloyd's algorithm until all runs have converged + while (iteration < maxIterations && !activeRuns.isEmpty) { + type WeightedPoint = (DoubleMatrix, Long) + def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = { + (p1._1.addi(p2._1), p1._2 + p2._2) + } + + val activeCenters = activeRuns.map(r => centers(r)).toArray + val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) + + // Find the sum and count of points mapping to each center + val totalContribs = data.mapPartitions { points => + val runs = activeCenters.length + val k = activeCenters(0).length + val dims = activeCenters(0)(0).length + + val sums = Array.fill(runs, k)(new DoubleMatrix(dims)) + val counts = Array.fill(runs, k)(0L) + + for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) { + val (bestCenter, cost) = KMeans.findClosest(centers, point) + costAccums(runIndex) += cost + sums(runIndex)(bestCenter).addi(new DoubleMatrix(point)) + counts(runIndex)(bestCenter) += 1 + } + + val contribs = for (i <- 0 until runs; j <- 0 until k) yield { + ((i, j), (sums(i)(j), counts(i)(j))) + } + contribs.iterator + }.reduceByKey(mergeContribs).collectAsMap() + + // Update the cluster centers and costs for each active run + for ((run, i) <- activeRuns.zipWithIndex) { + var changed = false + for (j <- 0 until k) { + val (sum, count) = totalContribs((i, j)) + if (count != 0) { + val newCenter = sum.divi(count).data + if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { + changed = true + } + centers(run)(j) = newCenter + } + } + if (!changed) { + active(run) = false + logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") + } + costs(run) = costAccums(i).value + } + + activeRuns = activeRuns.filter(active(_)) + iteration += 1 + } + + val bestRun = costs.zipWithIndex.min._2 + new KMeansModel(centers(bestRun)) + } + + /** + * Initialize `runs` sets of cluster centers at random. + */ + private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { + // Sample all the cluster centers in one pass to avoid repeated scans + val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq + Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) + } + + /** + * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. + * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries + * to find with dissimilar cluster centers by starting with a random center and then doing + * passes where more centers are chosen with probability proportional to their squared distance + * to the current cluster set. It results in a provable approximation to an optimal clustering. + * + * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. + */ + private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { + // Initialize each run's center to a random point + val seed = new Random().nextInt() + val sample = data.takeSample(true, runs, seed).toSeq + val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) + + // On each step, sample 2 * k points on average for each run with probability proportional + // to their squared distance from that run's current centers + for (step <- 0 until initializationSteps) { + val centerArrays = centers.map(_.toArray) + val sumCosts = data.flatMap { point => + for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) + }.reduceByKey(_ + _).collectAsMap() + val chosen = data.mapPartitionsWithIndex { (index, points) => + val rand = new Random(seed ^ (step << 16) ^ index) + for { + p <- points + r <- 0 until runs + if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r) + } yield (r, p) + }.collect() + for ((r, p) <- chosen) { + centers(r) += p + } + } + + // Finally, we might have a set of more than k candidate centers for each run; weigh each + // candidate by the number of points in the dataset mapping to it and run a local k-means++ + // on the weighted centers to pick just k of them + val centerArrays = centers.map(_.toArray) + val weightMap = data.flatMap { p => + for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0) + }.reduceByKey(_ + _).collectAsMap() + val finalCenters = (0 until runs).map { r => + val myCenters = centers(r).toArray + val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray + LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) + } + + finalCenters.toArray + } +} + + +/** + * Top-level methods for calling K-means clustering. + */ +object KMeans { + // Initialization mode names + val RANDOM = "random" + val K_MEANS_PARALLEL = "k-means||" + + def train( + data: RDD[Array[Double]], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String) + : KMeansModel = + { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .run(data) + } + + def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = { + train(data, k, maxIterations, runs, K_MEANS_PARALLEL) + } + + def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = { + train(data, k, maxIterations, 1, K_MEANS_PARALLEL) + } + + /** + * Return the index of the closest point in `centers` to `point`, as well as its distance. + */ + private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double]) + : (Int, Double) = + { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + for (i <- 0 until centers.length) { + val distance = MLUtils.squaredDistance(point, centers(i)) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + (bestIndex, bestDistance) + } + + /** + * Return the K-means cost of a given point against the given cluster centers. + */ + private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = { + var bestDistance = Double.PositiveInfinity + for (i <- 0 until centers.length) { + val distance = MLUtils.squaredDistance(point, centers(i)) + if (distance < bestDistance) { + bestDistance = distance + } + } + bestDistance + } + + def main(args: Array[String]) { + if (args.length < 4) { + println("Usage: KMeans []") + System.exit(1) + } + val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt) + val runs = if (args.length >= 5) args(4).toInt else 1 + val sc = new SparkContext(master, "KMeans") + val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache() + val model = KMeans.train(data, k, iters, runs) + val cost = model.computeCost(data) + println("Cluster centers:") + for (c <- model.clusterCenters) { + println(" " + c.mkString(" ")) + } + println("Cost: " + cost) + System.exit(0) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala new file mode 100644 index 0000000000..d1fe5d138d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.MLUtils + + +/** + * A clustering model for K-means. Each point belongs to the cluster with the closest center. + */ +class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable { + /** Total number of clusters. */ + def k: Int = clusterCenters.length + + /** Return the cluster index that a given point belongs to. */ + def predict(point: Array[Double]): Int = { + KMeans.findClosest(clusterCenters, point)._1 + } + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + def computeCost(data: RDD[Array[Double]]): Double = { + data.map(p => KMeans.pointCost(clusterCenters, p)).sum + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala new file mode 100644 index 0000000000..baf8251d8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.util.Random + +import org.jblas.{DoubleMatrix, SimpleBlas} + +/** + * An utility object to run K-means locally. This is private to the ML package because it's used + * in the initialization of KMeans but not meant to be publicly exposed. + */ +private[mllib] object LocalKMeans { + /** + * Run K-means++ on the weighted point set `points`. This first does the K-means++ + * initialization procedure and then roudns of Lloyd's algorithm. + */ + def kMeansPlusPlus( + seed: Int, + points: Array[Array[Double]], + weights: Array[Double], + k: Int, + maxIterations: Int) + : Array[Array[Double]] = + { + val rand = new Random(seed) + val dimensions = points(0).length + val centers = new Array[Array[Double]](k) + + // Initialize centers by sampling using the k-means++ procedure + centers(0) = pickWeighted(rand, points, weights) + for (i <- 1 until k) { + // Pick the next center with a probability proportional to cost under current centers + val curCenters = centers.slice(0, i) + val sum = points.zip(weights).map { case (p, w) => + w * KMeans.pointCost(curCenters, p) + }.sum + val r = rand.nextDouble() * sum + var cumulativeScore = 0.0 + var j = 0 + while (j < points.length && cumulativeScore < r) { + cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) + j += 1 + } + centers(i) = points(j-1) + } + + // Run up to maxIterations iterations of Lloyd's algorithm + val oldClosest = Array.fill(points.length)(-1) + var iteration = 0 + var moved = true + while (moved && iteration < maxIterations) { + moved = false + val sums = Array.fill(k)(new DoubleMatrix(dimensions)) + val counts = Array.fill(k)(0.0) + for ((p, i) <- points.zipWithIndex) { + val index = KMeans.findClosest(centers, p)._1 + SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index)) + counts(index) += weights(i) + if (index != oldClosest(i)) { + moved = true + oldClosest(i) = index + } + } + // Update centers + for (i <- 0 until k) { + if (counts(i) == 0.0) { + // Assign center to a random point + centers(i) = points(rand.nextInt(points.length)) + } else { + centers(i) = sums(i).divi(counts(i)).data + } + } + iteration += 1 + } + + centers + } + + private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = { + val r = rand.nextDouble() * weights.sum + var i = 0 + var curWeight = 0.0 + while (i < data.length && curWeight < r) { + curWeight += weights(i) + i += 1 + } + data(i - 1) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala new file mode 100644 index 0000000000..749e7364f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.jblas.DoubleMatrix + +/** + * Class used to compute the gradient for a loss function, given a single data point. + */ +abstract class Gradient extends Serializable { + /** + * Compute the gradient and loss given features of a single data point. + * + * @param data - Feature values for one data point. Column matrix of size nx1 + * where n is the number of features. + * @param label - Label for this data item. + * @param weights - Column matrix containing weights for every feature. + * + * @return A tuple of 2 elements. The first element is a column matrix containing the computed + * gradient and the second element is the loss computed at this data point. + * + */ + def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) +} + +/** + * Compute gradient and loss for a logistic loss function. + */ +class LogisticGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + val margin: Double = -1.0 * data.dot(weights) + val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + + val gradient = data.mul(gradientMultiplier) + val loss = + if (margin > 0) { + math.log(1 + math.exp(0 - margin)) + } else { + math.log(1 + math.exp(margin)) - margin + } + + (gradient, loss) + } +} + +/** + * Compute gradient and loss for a Least-squared loss function. + */ +class SquaredGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + val diff: Double = data.dot(weights) - label + + val loss = 0.5 * diff * diff + val gradient = data.mul(diff) + + (gradient, loss) + } +} + +/** + * Compute gradient and loss for a Hinge loss function. + * NOTE: This assumes that the labels are {0,1} + */ +class HingeGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + + val dotProduct = data.dot(weights) + + // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val labelScaled = 2 * label - 1.0 + + if (1.0 > labelScaled * dotProduct) { + (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) + } else { + (DoubleMatrix.zeros(1, weights.length), 0.0) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala new file mode 100644 index 0000000000..b62c9b3340 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.SparkContext._ + +import org.jblas.DoubleMatrix + +import scala.collection.mutable.ArrayBuffer + +/** + * Class used to solve an optimization problem using Gradient Descent. + * @param gradient Gradient function to be used. + * @param updater Updater to be used to update weights after every iteration. + */ +class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { + + private var stepSize: Double = 1.0 + private var numIterations: Int = 100 + private var regParam: Double = 0.0 + private var miniBatchFraction: Double = 1.0 + + /** + * Set the step size per-iteration of SGD. Default 1.0. + */ + def setStepSize(step: Double): this.type = { + this.stepSize = step + this + } + + /** + * Set fraction of data to be used for each SGD iteration. Default 1.0. + */ + def setMiniBatchFraction(fraction: Double): this.type = { + this.miniBatchFraction = fraction + this + } + + /** + * Set the number of iterations for SGD. Default 100. + */ + def setNumIterations(iters: Int): this.type = { + this.numIterations = iters + this + } + + /** + * Set the regularization parameter used for SGD. Default 0.0. + */ + def setRegParam(regParam: Double): this.type = { + this.regParam = regParam + this + } + + /** + * Set the gradient function to be used for SGD. + */ + def setGradient(gradient: Gradient): this.type = { + this.gradient = gradient + this + } + + + /** + * Set the updater function to be used for SGD. + */ + def setUpdater(updater: Updater): this.type = { + this.updater = updater + this + } + + def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) + : Array[Double] = { + + val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( + data, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFraction, + initialWeights) + weights + } + +} + +// Top-level method to run gradient descent. +object GradientDescent extends Logging { + /** + * Run gradient descent in parallel using mini batches. + * + * @param data - Input data for SGD. RDD of form (label, [feature values]). + * @param gradient - Gradient object that will be used to compute the gradient. + * @param updater - Updater object that will be used to update the model. + * @param stepSize - stepSize to be used during update. + * @param numIterations - number of iterations that SGD should be run. + * @param regParam - regularization parameter + * @param miniBatchFraction - fraction of the input data set that should be used for + * one iteration of SGD. Default value 1.0. + * + * @return A tuple containing two elements. The first element is a column matrix containing + * weights for every feature, and the second element is an array containing the stochastic + * loss computed for every iteration. + */ + def runMiniBatchSGD( + data: RDD[(Double, Array[Double])], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { + + val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + + val nexamples: Long = data.count() + val miniBatchSize = nexamples * miniBatchFraction + + // Initialize weights as a column vector + var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) + var regVal = 0.0 + + for (i <- 1 to numIterations) { + val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map { + case (y, features) => + val featuresCol = new DoubleMatrix(features.length, 1, features:_*) + val (grad, loss) = gradient.compute(featuresCol, y, weights) + (grad, loss) + }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) + + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } + + logInfo("GradientDescent finished. Last 10 stochastic losses %s".format( + stochasticLossHistory.takeRight(10).mkString(", "))) + + (weights.toArray, stochasticLossHistory.toArray) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala new file mode 100644 index 0000000000..50059d385d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.apache.spark.RDD + +trait Optimizer { + + /** + * Solve the provided convex optimization problem. + */ + def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala new file mode 100644 index 0000000000..4c51f4f881 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import scala.math._ +import org.jblas.DoubleMatrix + +/** + * Class used to update weights used in Gradient Descent. + */ +abstract class Updater extends Serializable { + /** + * Compute an updated value for weights given the gradient, stepSize, iteration number and + * regularization parameter. Also returns the regularization value computed using the + * *updated* weights. + * + * @param weightsOld - Column matrix of size nx1 where n is the number of features. + * @param gradient - Column matrix of size nx1 where n is the number of features. + * @param stepSize - step size across iterations + * @param iter - Iteration number + * @param regParam - Regularization parameter + * + * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, + * and the second element is the regularization value computed using updated weights. + */ + def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, + regParam: Double): (DoubleMatrix, Double) +} + +/** + * A simple updater that adaptively adjusts the learning rate the + * square root of the number of iterations. Does not perform any regularization. + */ +class SimpleUpdater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + (weightsOld.sub(normGradient), 0) + } +} + +/** + * Updater that adjusts learning rate and performs L1 regularization. + * + * The corresponding proximal operator used is the soft-thresholding function. + * That is, each weight component is shrunk towards 0 by shrinkageVal. + * + * If w > shrinkageVal, set weight component to w-shrinkageVal. + * If w < -shrinkageVal, set weight component to w+shrinkageVal. + * If -shrinkageVal < w < shrinkageVal, set weight component to 0. + * + * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) + */ +class L1Updater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + // Take gradient step + val newWeights = weightsOld.sub(normGradient) + // Soft thresholding + val shrinkageVal = regParam * thisIterStepSize + (0 until newWeights.length).foreach { i => + val wi = newWeights.get(i) + newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) + } + (newWeights, newWeights.norm1 * regParam) + } +} + +/** + * Updater that adjusts the learning rate and performs L2 regularization + */ +class SquaredL2Updater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0) + (newWeights, pow(newWeights.norm2, 2.0) * regParam) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala new file mode 100644 index 0000000000..218217acfe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import scala.collection.mutable.{ArrayBuffer, BitSet} +import scala.util.Random +import scala.util.Sorting + +import org.apache.spark.{HashPartitioner, Partitioner, SparkContext, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.KryoRegistrator +import org.apache.spark.SparkContext._ + +import com.esotericsoftware.kryo.Kryo +import org.jblas.{DoubleMatrix, SimpleBlas, Solve} + + +/** + * Out-link information for a user or product block. This includes the original user/product IDs + * of the elements within this block, and the list of destination blocks that each user or + * product will need to send its feature vector to. + */ +private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) + + +/** + * In-link information for a user (or product) block. This includes the original user/product IDs + * of the elements within this block, as well as an array of indices and ratings that specify + * which user in the block will be rated by which products from each product block (or vice-versa). + * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, + * indices and ratings, for the i'th product that will be sent to us by product block b (call this + * P). These arrays represent the users that product P had ratings for (by their index in this + * block), as well as the corresponding rating for each one. We can thus use this information when + * we get product block b's message to update the corresponding users. + */ +private[recommendation] case class InLinkBlock( + elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) + + +/** + * A more compact class to represent a rating than Tuple3[Int, Int, Double]. + */ +case class Rating(val user: Int, val product: Int, val rating: Double) + +/** + * Alternating Least Squares matrix factorization. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by precomputing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + */ +class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double) + extends Serializable +{ + def this() = this(-1, 10, 10, 0.01) + + /** + * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured + * number of blocks. Default: -1. + */ + def setBlocks(numBlocks: Int): ALS = { + this.numBlocks = numBlocks + this + } + + /** Set the rank of the feature matrices computed (number of features). Default: 10. */ + def setRank(rank: Int): ALS = { + this.rank = rank + this + } + + /** Set the number of iterations to run. Default: 10. */ + def setIterations(iterations: Int): ALS = { + this.iterations = iterations + this + } + + /** Set the regularization parameter, lambda. Default: 0.01. */ + def setLambda(lambda: Double): ALS = { + this.lambda = lambda + this + } + + /** + * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. + * Returns a MatrixFactorizationModel with feature vectors for each user and product. + */ + def run(ratings: RDD[Rating]): MatrixFactorizationModel = { + val numBlocks = if (this.numBlocks == -1) { + math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2) + } else { + this.numBlocks + } + + val partitioner = new HashPartitioner(numBlocks) + + val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) } + val ratingsByProductBlock = ratings.map{ rating => + (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating)) + } + + val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) + val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) + + // Initialize user and product factors randomly, but use a deterministic seed for each partition + // so that fault recovery works + val seedGen = new Random() + val seed1 = seedGen.nextInt() + val seed2 = seedGen.nextInt() + // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable + def hash(x: Int): Int = { + val r = x ^ (x >>> 20) ^ (x >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new Random(hash(seed1 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } + } + var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new Random(hash(seed2 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } + } + + for (iter <- 0 until iterations) { + // perform ALS update + products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda) + users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda) + } + + // Flatten and cache the two final RDDs to un-block them + val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) => + for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) + } + val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) => + for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) + } + + usersOut.persist() + productsOut.persist() + + new MatrixFactorizationModel(rank, usersOut, productsOut) + } + + /** + * Make the out-links table for a block of the users (or products) dataset given the list of + * (user, product, rating) values for the users in that block (or the opposite for products). + */ + private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = { + val userIds = ratings.map(_.user).distinct.sorted + val numUsers = userIds.length + val userIdToPos = userIds.zipWithIndex.toMap + val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) + for (r <- ratings) { + shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true + } + OutLinkBlock(userIds, shouldSend) + } + + /** + * Make the in-links table for a block of the users (or products) dataset given a list of + * (user, product, rating) values for the users in that block (or the opposite for products). + */ + private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = { + val userIds = ratings.map(_.user).distinct.sorted + val numUsers = userIds.length + val userIdToPos = userIds.zipWithIndex.toMap + // Split out our ratings by product block + val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) + for (r <- ratings) { + blockRatings(r.product % numBlocks) += r + } + val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) + for (productBlock <- 0 until numBlocks) { + // Create an array of (product, Seq(Rating)) ratings + val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray + // Sort them by product ID + val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { + def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = a._1 - b._1 + } + Sorting.quickSort(groupedRatings)(ordering) + // Translate the user IDs to indices based on userIdToPos + ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => + (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) + } + } + InLinkBlock(userIds, ratingsForBlock) + } + + /** + * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for + * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid + * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. + */ + private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)]) + : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = + { + val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) + val links = grouped.mapPartitionsWithIndex((blockId, elements) => { + val ratings = elements.map{_._2}.toArray + val inLinkBlock = makeInLinkBlock(numBlocks, ratings) + val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) + Iterator.single((blockId, (inLinkBlock, outLinkBlock))) + }, true) + links.persist(StorageLevel.MEMORY_AND_DISK) + (links.mapValues(_._1), links.mapValues(_._2)) + } + + /** + * Make a random factor vector with the given random. + */ + private def randomFactor(rank: Int, rand: Random): Array[Double] = { + Array.fill(rank)(rand.nextDouble) + } + + /** + * Compute the user feature vectors given the current products (or vice-versa). This first joins + * the products with their out-links to generate a set of messages to each destination block + * (specifically, the features for the products that user block cares about), then groups these + * by destination and joins them with the in-link info to figure out how to update each user. + * It returns an RDD of new feature vectors for each user block. + */ + private def updateFeatures( + products: RDD[(Int, Array[Array[Double]])], + productOutLinks: RDD[(Int, OutLinkBlock)], + userInLinks: RDD[(Int, InLinkBlock)], + partitioner: Partitioner, + rank: Int, + lambda: Double) + : RDD[(Int, Array[Array[Double]])] = + { + val numBlocks = products.partitions.size + productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => + val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) + for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { + if (outLinkBlock.shouldSend(p)(userBlock)) { + toSend(userBlock) += factors(p) + } + } + toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } + }.groupByKey(partitioner) + .join(userInLinks) + .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) } + } + + /** + * Compute the new feature vectors for a block of the users matrix given the list of factors + * it received from each product and its InLinkBlock. + */ + def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, + rank: Int, lambda: Double) + : Array[Array[Double]] = + { + // Sort the incoming block factor messages by block ID and make them an array + val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] + val numBlocks = blockFactors.length + val numUsers = inLinkBlock.elementIds.length + + // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since + // the matrices are symmetric + val triangleSize = rank * (rank + 1) / 2 + val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) + val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) + + // Some temp variables to avoid memory allocation + val tempXtX = DoubleMatrix.zeros(triangleSize) + val fullXtX = DoubleMatrix.zeros(rank, rank) + + // Compute the XtX and Xy values for each user by adding products it rated in each product block + for (productBlock <- 0 until numBlocks) { + for (p <- 0 until blockFactors(productBlock).length) { + val x = new DoubleMatrix(blockFactors(productBlock)(p)) + fillXtX(x, tempXtX) + val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) + for (i <- 0 until us.length) { + userXtX(us(i)).addi(tempXtX) + SimpleBlas.axpy(rs(i), x, userXy(us(i))) + } + } + } + + // Solve the least-squares problem for each user and return the new feature vectors + userXtX.zipWithIndex.map{ case (triangularXtX, index) => + // Compute the full XtX matrix from the lower-triangular part we got above + fillFullMatrix(triangularXtX, fullXtX) + // Add regularization + (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) + // Solve the resulting matrix, which is symmetric and positive-definite + Solve.solvePositive(fullXtX, userXy(index)).data + } + } + + /** + * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing + * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values + * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order. + */ + private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) { + var i = 0 + var pos = 0 + while (i < x.length) { + var j = 0 + while (j <= i) { + xtxDest.data(pos) = x.data(i) * x.data(j) + pos += 1 + j += 1 + } + i += 1 + } + } + + /** + * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square + * matrix that it represents, storing it into destMatrix. + */ + private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { + val rank = destMatrix.rows + var i = 0 + var pos = 0 + while (i < rank) { + var j = 0 + while (j <= i) { + destMatrix.data(i*rank + j) = triangularMatrix.data(pos) + destMatrix.data(j*rank + i) = triangularMatrix.data(pos) + pos += 1 + j += 1 + } + i += 1 + } + } +} + + +/** + * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton. + */ +object ALS { + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. This is done using a level of + * parallelism given by `blocks`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + * @param blocks level of parallelism to split computation into + */ + def train( + ratings: RDD[Rating], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int) + : MatrixFactorizationModel = + { + new ALS(blocks, rank, iterations, lambda).run(ratings) + } + + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. The level of parallelism is determined + * automatically based on the number of partitions in `ratings`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + */ + def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) + : MatrixFactorizationModel = + { + train(ratings, rank, iterations, lambda, -1) + } + + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. The level of parallelism is determined + * automatically based on the number of partitions in `ratings`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + */ + def train(ratings: RDD[Rating], rank: Int, iterations: Int) + : MatrixFactorizationModel = + { + train(ratings, rank, iterations, 0.01, -1) + } + + private class ALSRegistrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[Rating]) + } + } + + def main(args: Array[String]) { + if (args.length != 5 && args.length != 6) { + println("Usage: ALS []") + System.exit(1) + } + val (master, ratingsFile, rank, iters, outputDir) = + (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) + val blocks = if (args.length == 6) args(5).toInt else -1 + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) + System.setProperty("spark.kryo.referenceTracking", "false") + System.setProperty("spark.kryoserializer.buffer.mb", "8") + System.setProperty("spark.locality.wait", "10000") + val sc = new SparkContext(master, "ALS") + val ratings = sc.textFile(ratingsFile).map { line => + val fields = line.split(',') + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + } + val model = ALS.train(ratings, rank, iters, 0.01, blocks) + model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } + .saveAsTextFile(outputDir + "/userFeatures") + model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } + .saveAsTextFile(outputDir + "/productFeatures") + println("Final user/product features written to " + outputDir) + System.exit(0) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala new file mode 100644 index 0000000000..ae9fe48aec --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import org.apache.spark.RDD +import org.apache.spark.SparkContext._ + +import org.jblas._ + +/** + * Model representing the result of matrix factorization. + * + * @param rank Rank for the features in this model. + * @param userFeatures RDD of tuples where each tuple represents the userId and + * the features computed for this user. + * @param productFeatures RDD of tuples where each tuple represents the productId + * and the features computed for this product. + */ +class MatrixFactorizationModel( + val rank: Int, + val userFeatures: RDD[(Int, Array[Double])], + val productFeatures: RDD[(Int, Array[Double])]) + extends Serializable +{ + /** Predict the rating of one user for one product. */ + def predict(user: Int, product: Int): Double = { + val userVector = new DoubleMatrix(userFeatures.lookup(user).head) + val productVector = new DoubleMatrix(productFeatures.lookup(product).head) + userVector.dot(productVector) + } + + // TODO: Figure out what good bulk prediction methods would look like. + // Probably want a way to get the top users for a product or vice-versa. +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala new file mode 100644 index 0000000000..06015110ac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkException} +import org.apache.spark.mllib.optimization._ + +import org.jblas.DoubleMatrix + +/** + * GeneralizedLinearModel (GLM) represents a model trained using + * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and + * an intercept. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) + extends Serializable { + + // Create a column vector that can be used for predictions + private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) + + /** + * Predict the result given a data point and the weights learned. + * + * @param dataMatrix Row vector containing the features for this data point + * @param weightMatrix Column vector containing the weights of the model + * @param intercept Intercept of the model. + */ + def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double): Double + + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] = { + // A small optimization to avoid serializing the entire model. Only the weightsMatrix + // and intercept is needed. + val localWeights = weightsMatrix + val localIntercept = intercept + + testData.map { x => + val dataMatrix = new DoubleMatrix(1, x.length, x:_*) + predictPoint(dataMatrix, localWeights, localIntercept) + } + } + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + def predict(testData: Array[Double]): Double = { + val dataMat = new DoubleMatrix(1, testData.length, testData:_*) + predictPoint(dataMat, weightsMatrix, intercept) + } +} + +/** + * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM). + * This class should be extended with an Optimizer to create a new GLM. + */ +abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] + extends Logging with Serializable { + + protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() + + val optimizer: Optimizer + + protected var addIntercept: Boolean = true + + protected var validateData: Boolean = true + + /** + * Create a model given the weights and intercept + */ + protected def createModel(weights: Array[Double], intercept: Double): M + + /** + * Set if the algorithm should add an intercept. Default true. + */ + def setIntercept(addIntercept: Boolean): this.type = { + this.addIntercept = addIntercept + this + } + + /** + * Set if the algorithm should validate data before training. Default true. + */ + def setValidateData(validateData: Boolean): this.type = { + this.validateData = validateData + this + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + */ + def run(input: RDD[LabeledPoint]) : M = { + val nfeatures: Int = input.first().features.length + val initialWeights = Array.fill(nfeatures)(1.0) + run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + */ + def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + + // Check the data properties before running the optimizer + if (validateData && !validators.forall(func => func(input))) { + throw new SparkException("Input validation failed.") + } + + // Add a extra variable consisting of all 1.0's for the intercept. + val data = if (addIntercept) { + input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } + + val initialWeightsWithIntercept = if (addIntercept) { + Array(1.0, initialWeights:_*) + } else { + initialWeights + } + + val weights = optimizer.optimize(data, initialWeightsWithIntercept) + val intercept = weights(0) + val weightsScaled = weights.tail + + val model = createModel(weightsScaled, intercept) + + logInfo("Final model weights " + model.weights.mkString(",")) + logInfo("Final model intercept " + model.intercept) + model + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala new file mode 100644 index 0000000000..63240e24dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +/** + * Class that represents the features and labels of a data point. + * + * @param label Label for this data point. + * @param features List of features for this data point. + */ +case class LabeledPoint(val label: Double, val features: Array[Double]) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala new file mode 100644 index 0000000000..df3beb1959 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using Lasso. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LassoModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with L1-regularization using Stochastic Gradient Descent. + */ +class LassoWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LassoModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new L1Updater() + @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + // We don't want to penalize the intercept, so set this to false. + setIntercept(false) + + var yMean = 0.0 + var xColMean: DoubleMatrix = _ + var xColSd: DoubleMatrix = _ + + /** + * Construct a Lasso object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + val weightsScaled = weightsMat.div(xColSd) + val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + + new LassoModel(weightsScaled.data, interceptScaled) + } + + override def run( + input: RDD[LabeledPoint], + initialWeights: Array[Double]) + : LassoModel = + { + val nfeatures: Int = input.first.features.length + val nexamples: Long = input.count() + + // To avoid penalizing the intercept, we center and scale the data. + val stats = MLUtils.computeStats(input, nfeatures, nexamples) + yMean = stats._1 + xColMean = stats._2 + xColSd = stats._3 + + val normalizedData = input.map { point => + val yNormalized = point.label - yMean + val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) + val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) + LabeledPoint(yNormalized, featuresNormalized.toArray) + } + + super.run(normalizedData, initialWeights) + } +} + +/** + * Top-level methods for calling Lasso. + */ +object LassoWithSGD { + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LassoModel = + { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : LassoModel = + { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LassoModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : LassoModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LassoModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LassoModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: Lasso ") + System.exit(1) + } + val sc = new SparkContext(args(0), "Lasso") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala new file mode 100644 index 0000000000..71f968471c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using LinearRegression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LinearRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with no regularization using Stochastic Gradient Descent. + */ +class LinearRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LinearRegressionModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new SimpleUpdater() + val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setMiniBatchFraction(miniBatchFraction) + + /** + * Construct a LinearRegression object with default parameters + */ + def this() = this(1.0, 100, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new LinearRegressionModel(weights, intercept) + } +} + +/** + * Top-level methods for calling LinearRegression. + */ +object LinearRegressionWithSGD { + + /** + * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LinearRegressionModel = + { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double) + : LinearRegressionModel = + { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LinearRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double) + : LinearRegressionModel = + { + train(input, numIterations, stepSize, 1.0) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LinearRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LinearRegressionModel = + { + train(input, numIterations, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: LinearRegression ") + System.exit(1) + } + val sc = new SparkContext(args(0), "LinearRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala new file mode 100644 index 0000000000..8dd325efc0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.RDD + +trait RegressionModel extends Serializable { + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + def predict(testData: Array[Double]): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala new file mode 100644 index 0000000000..228ab9e4e8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using RidgeRegression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class RidgeRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with L2-regularization using Stochastic Gradient Descent. + */ +class RidgeRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[RidgeRegressionModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new SquaredL2Updater() + + @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + // We don't want to penalize the intercept in RidgeRegression, so set this to false. + setIntercept(false) + + var yMean = 0.0 + var xColMean: DoubleMatrix = _ + var xColSd: DoubleMatrix = _ + + /** + * Construct a RidgeRegression object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + val weightsScaled = weightsMat.div(xColSd) + val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + + new RidgeRegressionModel(weightsScaled.data, interceptScaled) + } + + override def run( + input: RDD[LabeledPoint], + initialWeights: Array[Double]) + : RidgeRegressionModel = + { + val nfeatures: Int = input.first.features.length + val nexamples: Long = input.count() + + // To avoid penalizing the intercept, we center and scale the data. + val stats = MLUtils.computeStats(input, nfeatures, nexamples) + yMean = stats._1 + xColMean = stats._2 + xColSd = stats._3 + + val normalizedData = input.map { point => + val yNormalized = point.label - yMean + val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) + val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) + LabeledPoint(yNormalized, featuresNormalized.toArray) + } + + super.run(normalizedData, initialWeights) + } +} + +/** + * Top-level methods for calling RidgeRegression. + */ +object RidgeRegressionWithSGD { + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : RidgeRegressionModel = + { + new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( + input, initialWeights) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : RidgeRegressionModel = + { + new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a RidgeRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : RidgeRegressionModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a RidgeRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : RidgeRegressionModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: RidgeRegression " + + " ") + System.exit(1) + } + val sc = new SparkContext(args(0), "RidgeRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, + args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala new file mode 100644 index 0000000000..7fd4623071 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.{RDD, Logging} +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * A collection of methods used to validate data before applying ML algorithms. + */ +object DataValidators extends Logging { + + /** + * Function to check if labels used for classification are either zero or one. + * + * @param data - input data set that needs to be checked + * + * @return True if labels are all zero or one, false otherwise. + */ + val classificationLabels: RDD[LabeledPoint] => Boolean = { data => + val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() + if (numInvalid != 0) { + logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") + } + numInvalid == 0 + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala new file mode 100644 index 0000000000..6500d47183 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} + +/** + * Generate test data for KMeans. This class first chooses k cluster centers + * from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian + * cluster with scale 1 around each center. + */ + +object KMeansDataGenerator { + + /** + * Generate an RDD containing test data for KMeans. + * + * @param sc SparkContext to use for creating the RDD + * @param numPoints Number of points that will be contained in the RDD + * @param k Number of clusters + * @param d Number of dimensions + * @param r Scaling factor for the distribution of the initial centers + * @param numPartitions Number of partitions of the generated RDD; default 2 + */ + def generateKMeansRDD( + sc: SparkContext, + numPoints: Int, + k: Int, + d: Int, + r: Double, + numPartitions: Int = 2) + : RDD[Array[Double]] = + { + // First, generate some centers + val rand = new Random(42) + val centers = Array.fill(k)(Array.fill(d)(rand.nextGaussian() * r)) + // Then generate points around each center + sc.parallelize(0 until numPoints, numPartitions).map { idx => + val center = centers(idx % k) + val rand2 = new Random(42 + idx) + Array.tabulate(d)(i => center(i) + rand2.nextGaussian()) + } + } + + def main(args: Array[String]) { + if (args.length < 6) { + println("Usage: KMeansGenerator " + + " []") + System.exit(1) + } + + val sparkMaster = args(0) + val outputPath = args(1) + val numPoints = args(2).toInt + val k = args(3).toInt + val d = args(4).toInt + val r = args(5).toDouble + val parts = if (args.length >= 7) args(6).toInt else 2 + + val sc = new SparkContext(sparkMaster, "KMeansDataGenerator") + val data = generateKMeansRDD(sc, numPoints, k, d, r, parts) + data.map(_.mkString(" ")).saveAsTextFile(outputPath) + + System.exit(0) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala new file mode 100644 index 0000000000..4c49d484b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.jblas.DoubleMatrix + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate sample data used for Linear Data. This class generates + * uniformly random values for every feature and adds Gaussian noise with mean `eps` to the + * response variable `Y`. + */ +object LinearDataGenerator { + + /** + * Return a Java List of synthetic data randomly generated according to a multi + * collinear model. + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @return Java List of input. + */ + def generateLinearInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): java.util.List[LabeledPoint] = { + seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + } + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int, + eps: Double = 0.1): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + val y = x.map { xi => + (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian() + } + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + + /** + * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso, + * and uregularized variants. + * + * @param sc SparkContext to be used for generating the RDD. + * @param nexamples Number of examples that will be contained in the RDD. + * @param nfeatures Number of features to generate for each example. + * @param eps Epsilon factor by which examples are scaled. + * @param weights Weights associated with the first weights.length features. + * @param nparts Number of partitions in the RDD. Default value is 2. + * + * @return RDD of LabeledPoint containing sample data. + */ + def generateLinearRDD( + sc: SparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int = 2, + intercept: Double = 0.0) : RDD[LabeledPoint] = { + org.jblas.util.Random.seed(42) + // Random values distributed uniformly in [-0.5, 0.5] + val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + + val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p => + val seed = 42 + p + val examplesInPartition = nexamples / nparts + generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps) + } + data + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: LinearDataGenerator " + + " [num_examples] [num_features] [num_partitions]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 100 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + val eps = 10 + + val sc = new SparkContext(sparkMaster, "LinearDataGenerator") + val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts) + + MLUtils.saveLabeledData(data, outputPath) + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala new file mode 100644 index 0000000000..f553298fc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate test data for LogisticRegression. This class chooses positive labels + * with probability `probOne` and scales features for positive examples by `eps`. + */ + +object LogisticRegressionDataGenerator { + + /** + * Generate an RDD containing test data for LogisticRegression. + * + * @param sc SparkContext to use for creating the RDD. + * @param nexamples Number of examples that will be contained in the RDD. + * @param nfeatures Number of features to generate for each example. + * @param eps Epsilon factor by which positive examples are scaled. + * @param nparts Number of partitions of the generated RDD. Default value is 2. + * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. + */ + def generateLogisticRDD( + sc: SparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int = 2, + probOne: Double = 0.5): RDD[LabeledPoint] = { + val data = sc.parallelize(0 until nexamples, nparts).map { idx => + val rnd = new Random(42 + idx) + + val y = if (idx % 2 == 0) 0.0 else 1.0 + val x = Array.fill[Double](nfeatures) { + rnd.nextGaussian() + (y * eps) + } + LabeledPoint(y, x) + } + data + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: LogisticRegressionGenerator " + + " ") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + val eps = 3 + + val sc = new SparkContext(sparkMaster, "LogisticRegressionDataGenerator") + val data = generateLogisticRDD(sc, nexamples, nfeatures, eps, parts) + + MLUtils.saveLabeledData(data, outputPath) + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala new file mode 100644 index 0000000000..7eb69ae81c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import scala.util.Random + +import org.jblas.DoubleMatrix + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.util.MLUtils + +/** +* Generate RDD(s) containing data for Matrix Factorization. +* +* This method samples training entries according to the oversampling factor +* 'trainSampFact', which is a multiplicative factor of the number of +* degrees of freedom of the matrix: rank*(m+n-rank). +* +* It optionally samples entries for a testing matrix using +* 'testSampFact', the percentage of the number of training entries +* to use for testing. +* +* This method takes the following inputs: +* sparkMaster (String) The master URL. +* outputPath (String) Directory to save output. +* m (Int) Number of rows in data matrix. +* n (Int) Number of columns in data matrix. +* rank (Int) Underlying rank of data matrix. +* trainSampFact (Double) Oversampling factor. +* noise (Boolean) Whether to add gaussian noise to training data. +* sigma (Double) Standard deviation of added gaussian noise. +* test (Boolean) Whether to create testing RDD. +* testSampFact (Double) Percentage of training data to use as test data. +*/ + +object MFDataGenerator{ + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: MFDataGenerator " + + " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val m: Int = if (args.length > 2) args(2).toInt else 100 + val n: Int = if (args.length > 3) args(3).toInt else 100 + val rank: Int = if (args.length > 4) args(4).toInt else 10 + val trainSampFact: Double = if (args.length > 5) args(5).toDouble else 1.0 + val noise: Boolean = if (args.length > 6) args(6).toBoolean else false + val sigma: Double = if (args.length > 7) args(7).toDouble else 0.1 + val test: Boolean = if (args.length > 8) args(8).toBoolean else false + val testSampFact: Double = if (args.length > 9) args(9).toDouble else 0.1 + + val sc = new SparkContext(sparkMaster, "MFDataGenerator") + + val A = DoubleMatrix.randn(m, rank) + val B = DoubleMatrix.randn(rank, n) + val z = 1 / (scala.math.sqrt(scala.math.sqrt(rank))) + A.mmuli(z) + B.mmuli(z) + val fullData = A.mmul(B) + + val df = rank * (m + n - rank) + val sampSize = scala.math.min(scala.math.round(trainSampFact * df), + scala.math.round(.99 * m * n)).toInt + val rand = new Random() + val mn = m * n + val shuffled = rand.shuffle(1 to mn toIterable) + + val omega = shuffled.slice(0, sampSize) + val ordered = omega.sortWith(_ < _).toArray + val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) + .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + + // optionally add gaussian noise + if (noise) { + trainData.map(x => (x._1, x._2, x._3 + rand.nextGaussian * sigma)) + } + + trainData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) + + // optionally generate testing data + if (test) { + val testSampSize = scala.math + .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) + val testOrdered = testOmega.sortWith(_ < _).toArray + val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) + .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) + } + + sc.stop() + + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala new file mode 100644 index 0000000000..0aeafbe23c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext._ + +import org.jblas.DoubleMatrix +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Helper methods to load, save and pre-process data used in ML Lib. + */ +object MLUtils { + + /** + * Load labeled data from a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.split(',') + val label = parts(0).toDouble + val features = parts(1).trim().split(' ').map(_.toDouble) + LabeledPoint(label, features) + } + } + + /** + * Save labeled data to a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param data An RDD of LabeledPoints containing data to be saved. + * @param dir Directory to save the data. + */ + def saveLabeledData(data: RDD[LabeledPoint], dir: String) { + val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) + dataStr.saveAsTextFile(dir) + } + + /** + * Utility function to compute mean and standard deviation on a given dataset. + * + * @param data - input data set whose statistics are computed + * @param nfeatures - number of features + * @param nexamples - number of examples in input dataset + * + * @return (yMean, xColMean, xColSd) - Tuple consisting of + * yMean - mean of the labels + * xColMean - Row vector with mean for every column (or feature) of the input data + * xColSd - Row vector standard deviation for every column (or feature) of the input data. + */ + def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): + (Double, DoubleMatrix, DoubleMatrix) = { + val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples + + // NOTE: We shuffle X by column here to compute column sum and sum of squares. + val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => + val nCols = labeledPoint.features.length + // Traverse over every column and emit (col, value, value^2) + Iterator.tabulate(nCols) { i => + (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) + } + }.reduceByKey { case(x1, x2) => + (x1._1 + x2._1, x1._2 + x2._2) + } + val xColSumsMap = xColSumSq.collectAsMap() + + val xColMean = DoubleMatrix.zeros(nfeatures, 1) + val xColSd = DoubleMatrix.zeros(nfeatures, 1) + + // Compute mean and unbiased variance using column sums + var col = 0 + while (col < nfeatures) { + xColMean.put(col, xColSumsMap(col)._1 / nexamples) + val variance = + (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / (nexamples) + xColSd.put(col, math.sqrt(variance)) + col += 1 + } + + (yMean, xColMean, xColSd) + } + + /** + * Return the squared Euclidean distance between two vectors. + */ + def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = { + if (v1.length != v2.length) { + throw new IllegalArgumentException("Vector sizes don't match") + } + var i = 0 + var sum = 0.0 + while (i < v1.length) { + sum += (v1(i) - v2(i)) * (v1(i) - v2(i)) + i += 1 + } + sum + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala new file mode 100644 index 0000000000..d3f191b05b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -0,0 +1,50 @@ +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} + +import org.jblas.DoubleMatrix +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate sample data used for SVM. This class generates uniform random values + * for the features and adds Gaussian noise with weight 0.1 to generate labels. + */ +object SVMDataGenerator { + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: SVMGenerator " + + " [num_examples] [num_features] [num_partitions]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + + val sc = new SparkContext(sparkMaster, "SVMGenerator") + + val globalRnd = new Random(94720) + val trueWeights = new DoubleMatrix(1, nfeatures + 1, + Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) + + val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => + val rnd = new Random(42 + idx) + + val x = Array.fill[Double](nfeatures) { + rnd.nextDouble() * 2.0 - 1.0 + } + val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1 + val y = if (yD < 0) 0.0 else 1.0 + LabeledPoint(y, x) + } + + MLUtils.saveLabeledData(data, outputPath) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala deleted file mode 100644 index 70fae8c15a..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala +++ /dev/null @@ -1,21 +0,0 @@ -package spark.mllib.classification - -import spark.RDD - -trait ClassificationModel extends Serializable { - /** - * Predict values for the given data set using the model trained. - * - * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction - */ - def predict(testData: RDD[Array[Double]]): RDD[Double] - - /** - * Predict values for a single data point using the model trained. - * - * @param testData array representing a single data point - * @return Int prediction from the trained model - */ - def predict(testData: Array[Double]): Double -} diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala deleted file mode 100644 index 482e4a6745..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.math.round - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.regression._ -import spark.mllib.util.MLUtils -import spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix - -/** - * Classification model trained using Logistic Regression. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class LogisticRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept - round(1.0/ (1.0 + math.exp(margin * -1))) - } -} - -/** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - */ -class LogisticRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LogisticRegressionModel] - with Serializable { - - val gradient = new LogisticGradient() - val updater = new SimpleUpdater() - override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - override val validators = List(DataValidators.classificationLabels) - - /** - * Construct a LogisticRegression object with default parameters - */ - def this() = this(1.0, 100, 0.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new LogisticRegressionModel(weights, intercept) - } -} - -/** - * Top-level methods for calling Logistic Regression. - * NOTE: Labels used in Logistic Regression should be {0, 1} - */ -object LogisticRegressionWithSGD { - // NOTE(shivaram): We use multiple train methods instead of default arguments to support - // Java programs. - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input, initialWeights) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. We use the entire data - * set to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - - * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double) - : LogisticRegressionModel = - { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using a step size of 1.0. We use the entire data set - * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : LogisticRegressionModel = - { - train(input, numIterations, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 4) { - println("Usage: LogisticRegression " + - "") - System.exit(1) - } - val sc = new SparkContext(args(0), "LogisticRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala deleted file mode 100644 index 69393cd7b0..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.math.signum - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.regression._ -import spark.mllib.util.MLUtils -import spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix - -/** - * Model built using SVM. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class SVMModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - val margin = dataMatrix.dot(weightMatrix) + intercept - if (margin < 0) 0.0 else 1.0 - } -} - -/** - * Train an SVM using Stochastic Gradient Descent. - * NOTE: Labels used in SVM should be {0, 1} - */ -class SVMWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { - - val gradient = new HingeGradient() - val updater = new SquaredL2Updater() - override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - override val validators = List(DataValidators.classificationLabels) - - /** - * Construct a SVM object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new SVMModel(weights, intercept) - } -} - -/** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1} - */ -object SVMWithSGD { - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a SVMModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double) - : SVMModel = - { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a SVMModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : SVMModel = - { - train(input, numIterations, 1.0, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: SVM ") - System.exit(1) - } - val sc = new SparkContext(args(0), "SVM") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala deleted file mode 100644 index 97e3d110ae..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import spark.{SparkContext, RDD} -import spark.SparkContext._ -import spark.Logging -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - - -/** - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. - * - * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given - * to it should be cached by the user. - */ -class KMeans private ( - var k: Int, - var maxIterations: Int, - var runs: Int, - var initializationMode: String, - var initializationSteps: Int, - var epsilon: Double) - extends Serializable with Logging -{ - private type ClusterCenters = Array[Array[Double]] - - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) - - /** Set the number of clusters to create (k). Default: 2. */ - def setK(k: Int): KMeans = { - this.k = k - this - } - - /** Set maximum number of iterations to run. Default: 20. */ - def setMaxIterations(maxIterations: Int): KMeans = { - this.maxIterations = maxIterations - this - } - - /** - * Set the initialization algorithm. This can be either "random" to choose random points as - * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ - * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. - */ - def setInitializationMode(initializationMode: String): KMeans = { - if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { - throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) - } - this.initializationMode = initializationMode - this - } - - /** - * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Default: 1. - */ - def setRuns(runs: Int): KMeans = { - if (runs <= 0) { - throw new IllegalArgumentException("Number of runs must be positive") - } - this.runs = runs - this - } - - /** - * Set the number of steps for the k-means|| initialization mode. This is an advanced - * setting -- the default of 5 is almost always enough. Default: 5. - */ - def setInitializationSteps(initializationSteps: Int): KMeans = { - if (initializationSteps <= 0) { - throw new IllegalArgumentException("Number of initialization steps must be positive") - } - this.initializationSteps = initializationSteps - this - } - - /** - * Set the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - */ - def setEpsilon(epsilon: Double): KMeans = { - this.epsilon = epsilon - this - } - - /** - * Train a K-means model on the given set of points; `data` should be cached for high - * performance, because this is an iterative algorithm. - */ - def run(data: RDD[Array[Double]]): KMeansModel = { - // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable - - val sc = data.sparkContext - - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) - } else { - initKMeansParallel(data) - } - - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) - - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) - var iteration = 0 - - // Execute iterations of Lloyd's algorithm until all runs have converged - while (iteration < maxIterations && !activeRuns.isEmpty) { - type WeightedPoint = (DoubleMatrix, Long) - def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = { - (p1._1.addi(p2._1), p1._2 + p2._2) - } - - val activeCenters = activeRuns.map(r => centers(r)).toArray - val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) - - // Find the sum and count of points mapping to each center - val totalContribs = data.mapPartitions { points => - val runs = activeCenters.length - val k = activeCenters(0).length - val dims = activeCenters(0)(0).length - - val sums = Array.fill(runs, k)(new DoubleMatrix(dims)) - val counts = Array.fill(runs, k)(0L) - - for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) { - val (bestCenter, cost) = KMeans.findClosest(centers, point) - costAccums(runIndex) += cost - sums(runIndex)(bestCenter).addi(new DoubleMatrix(point)) - counts(runIndex)(bestCenter) += 1 - } - - val contribs = for (i <- 0 until runs; j <- 0 until k) yield { - ((i, j), (sums(i)(j), counts(i)(j))) - } - contribs.iterator - }.reduceByKey(mergeContribs).collectAsMap() - - // Update the cluster centers and costs for each active run - for ((run, i) <- activeRuns.zipWithIndex) { - var changed = false - for (j <- 0 until k) { - val (sum, count) = totalContribs((i, j)) - if (count != 0) { - val newCenter = sum.divi(count).data - if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { - changed = true - } - centers(run)(j) = newCenter - } - } - if (!changed) { - active(run) = false - logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") - } - costs(run) = costAccums(i).value - } - - activeRuns = activeRuns.filter(active(_)) - iteration += 1 - } - - val bestRun = costs.zipWithIndex.min._2 - new KMeansModel(centers(bestRun)) - } - - /** - * Initialize `runs` sets of cluster centers at random. - */ - private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { - // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq - Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) - } - - /** - * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. - * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries - * to find with dissimilar cluster centers by starting with a random center and then doing - * passes where more centers are chosen with probability proportional to their squared distance - * to the current cluster set. It results in a provable approximation to an optimal clustering. - * - * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. - */ - private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { - // Initialize each run's center to a random point - val seed = new Random().nextInt() - val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) - - // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers - for (step <- 0 until initializationSteps) { - val centerArrays = centers.map(_.toArray) - val sumCosts = data.flatMap { point => - for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => - val rand = new Random(seed ^ (step << 16) ^ index) - for { - p <- points - r <- 0 until runs - if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r) - } yield (r, p) - }.collect() - for ((r, p) <- chosen) { - centers(r) += p - } - } - - // Finally, we might have a set of more than k candidate centers for each run; weigh each - // candidate by the number of points in the dataset mapping to it and run a local k-means++ - // on the weighted centers to pick just k of them - val centerArrays = centers.map(_.toArray) - val weightMap = data.flatMap { p => - for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0) - }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => - val myCenters = centers(r).toArray - val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray - LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) - } - - finalCenters.toArray - } -} - - -/** - * Top-level methods for calling K-means clustering. - */ -object KMeans { - // Initialization mode names - val RANDOM = "random" - val K_MEANS_PARALLEL = "k-means||" - - def train( - data: RDD[Array[Double]], - k: Int, - maxIterations: Int, - runs: Int, - initializationMode: String) - : KMeansModel = - { - new KMeans().setK(k) - .setMaxIterations(maxIterations) - .setRuns(runs) - .setInitializationMode(initializationMode) - .run(data) - } - - def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = { - train(data, k, maxIterations, runs, K_MEANS_PARALLEL) - } - - def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = { - train(data, k, maxIterations, 1, K_MEANS_PARALLEL) - } - - /** - * Return the index of the closest point in `centers` to `point`, as well as its distance. - */ - private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double]) - : (Int, Double) = - { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - for (i <- 0 until centers.length) { - val distance = MLUtils.squaredDistance(point, centers(i)) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - (bestIndex, bestDistance) - } - - /** - * Return the K-means cost of a given point against the given cluster centers. - */ - private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = { - var bestDistance = Double.PositiveInfinity - for (i <- 0 until centers.length) { - val distance = MLUtils.squaredDistance(point, centers(i)) - if (distance < bestDistance) { - bestDistance = distance - } - } - bestDistance - } - - def main(args: Array[String]) { - if (args.length < 4) { - println("Usage: KMeans []") - System.exit(1) - } - val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt) - val runs = if (args.length >= 5) args(4).toInt else 1 - val sc = new SparkContext(master, "KMeans") - val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache() - val model = KMeans.train(data, k, iters, runs) - val cost = model.computeCost(data) - println("Cluster centers:") - for (c <- model.clusterCenters) { - println(" " + c.mkString(" ")) - } - println("Cost: " + cost) - System.exit(0) - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala deleted file mode 100644 index b8f80e80cd..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import spark.RDD -import spark.SparkContext._ -import spark.mllib.util.MLUtils - - -/** - * A clustering model for K-means. Each point belongs to the cluster with the closest center. - */ -class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable { - /** Total number of clusters. */ - def k: Int = clusterCenters.length - - /** Return the cluster index that a given point belongs to. */ - def predict(point: Array[Double]): Int = { - KMeans.findClosest(clusterCenters, point)._1 - } - - /** - * Return the K-means cost (sum of squared distances of points to their nearest center) for this - * model on the given data. - */ - def computeCost(data: RDD[Array[Double]]): Double = { - data.map(p => KMeans.pointCost(clusterCenters, p)).sum - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala deleted file mode 100644 index 89fe7d7e85..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import scala.util.Random - -import org.jblas.{DoubleMatrix, SimpleBlas} - -/** - * An utility object to run K-means locally. This is private to the ML package because it's used - * in the initialization of KMeans but not meant to be publicly exposed. - */ -private[mllib] object LocalKMeans { - /** - * Run K-means++ on the weighted point set `points`. This first does the K-means++ - * initialization procedure and then roudns of Lloyd's algorithm. - */ - def kMeansPlusPlus( - seed: Int, - points: Array[Array[Double]], - weights: Array[Double], - k: Int, - maxIterations: Int) - : Array[Array[Double]] = - { - val rand = new Random(seed) - val dimensions = points(0).length - val centers = new Array[Array[Double]](k) - - // Initialize centers by sampling using the k-means++ procedure - centers(0) = pickWeighted(rand, points, weights) - for (i <- 1 until k) { - // Pick the next center with a probability proportional to cost under current centers - val curCenters = centers.slice(0, i) - val sum = points.zip(weights).map { case (p, w) => - w * KMeans.pointCost(curCenters, p) - }.sum - val r = rand.nextDouble() * sum - var cumulativeScore = 0.0 - var j = 0 - while (j < points.length && cumulativeScore < r) { - cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) - j += 1 - } - centers(i) = points(j-1) - } - - // Run up to maxIterations iterations of Lloyd's algorithm - val oldClosest = Array.fill(points.length)(-1) - var iteration = 0 - var moved = true - while (moved && iteration < maxIterations) { - moved = false - val sums = Array.fill(k)(new DoubleMatrix(dimensions)) - val counts = Array.fill(k)(0.0) - for ((p, i) <- points.zipWithIndex) { - val index = KMeans.findClosest(centers, p)._1 - SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index)) - counts(index) += weights(i) - if (index != oldClosest(i)) { - moved = true - oldClosest(i) = index - } - } - // Update centers - for (i <- 0 until k) { - if (counts(i) == 0.0) { - // Assign center to a random point - centers(i) = points(rand.nextInt(points.length)) - } else { - centers(i) = sums(i).divi(counts(i)).data - } - } - iteration += 1 - } - - centers - } - - private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = { - val r = rand.nextDouble() * weights.sum - var i = 0 - var curWeight = 0.0 - while (i < data.length && curWeight < r) { - curWeight += weights(i) - i += 1 - } - data(i - 1) - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala deleted file mode 100644 index 05568f55af..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import org.jblas.DoubleMatrix - -/** - * Class used to compute the gradient for a loss function, given a single data point. - */ -abstract class Gradient extends Serializable { - /** - * Compute the gradient and loss given features of a single data point. - * - * @param data - Feature values for one data point. Column matrix of size nx1 - * where n is the number of features. - * @param label - Label for this data item. - * @param weights - Column matrix containing weights for every feature. - * - * @return A tuple of 2 elements. The first element is a column matrix containing the computed - * gradient and the second element is the loss computed at this data point. - * - */ - def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) -} - -/** - * Compute gradient and loss for a logistic loss function. - */ -class LogisticGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val margin: Double = -1.0 * data.dot(weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - val gradient = data.mul(gradientMultiplier) - val loss = - if (margin > 0) { - math.log(1 + math.exp(0 - margin)) - } else { - math.log(1 + math.exp(margin)) - margin - } - - (gradient, loss) - } -} - -/** - * Compute gradient and loss for a Least-squared loss function. - */ -class SquaredGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val diff: Double = data.dot(weights) - label - - val loss = 0.5 * diff * diff - val gradient = data.mul(diff) - - (gradient, loss) - } -} - -/** - * Compute gradient and loss for a Hinge loss function. - * NOTE: This assumes that the labels are {0,1} - */ -class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - - val dotProduct = data.dot(weights) - - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) - // Therefore the gradient is -(2y - 1)*x - val labelScaled = 2 * label - 1.0 - - if (1.0 > labelScaled * dotProduct) { - (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) - } else { - (DoubleMatrix.zeros(1, weights.length), 0.0) - } - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala deleted file mode 100644 index 31917df7e8..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import spark.{Logging, RDD, SparkContext} -import spark.SparkContext._ - -import org.jblas.DoubleMatrix - -import scala.collection.mutable.ArrayBuffer - -/** - * Class used to solve an optimization problem using Gradient Descent. - * @param gradient Gradient function to be used. - * @param updater Updater to be used to update weights after every iteration. - */ -class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { - - private var stepSize: Double = 1.0 - private var numIterations: Int = 100 - private var regParam: Double = 0.0 - private var miniBatchFraction: Double = 1.0 - - /** - * Set the step size per-iteration of SGD. Default 1.0. - */ - def setStepSize(step: Double): this.type = { - this.stepSize = step - this - } - - /** - * Set fraction of data to be used for each SGD iteration. Default 1.0. - */ - def setMiniBatchFraction(fraction: Double): this.type = { - this.miniBatchFraction = fraction - this - } - - /** - * Set the number of iterations for SGD. Default 100. - */ - def setNumIterations(iters: Int): this.type = { - this.numIterations = iters - this - } - - /** - * Set the regularization parameter used for SGD. Default 0.0. - */ - def setRegParam(regParam: Double): this.type = { - this.regParam = regParam - this - } - - /** - * Set the gradient function to be used for SGD. - */ - def setGradient(gradient: Gradient): this.type = { - this.gradient = gradient - this - } - - - /** - * Set the updater function to be used for SGD. - */ - def setUpdater(updater: Updater): this.type = { - this.updater = updater - this - } - - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) - : Array[Double] = { - - val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( - data, - gradient, - updater, - stepSize, - numIterations, - regParam, - miniBatchFraction, - initialWeights) - weights - } - -} - -// Top-level method to run gradient descent. -object GradientDescent extends Logging { - /** - * Run gradient descent in parallel using mini batches. - * - * @param data - Input data for SGD. RDD of form (label, [feature values]). - * @param gradient - Gradient object that will be used to compute the gradient. - * @param updater - Updater object that will be used to update the model. - * @param stepSize - stepSize to be used during update. - * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter - * @param miniBatchFraction - fraction of the input data set that should be used for - * one iteration of SGD. Default value 1.0. - * - * @return A tuple containing two elements. The first element is a column matrix containing - * weights for every feature, and the second element is an array containing the stochastic - * loss computed for every iteration. - */ - def runMiniBatchSGD( - data: RDD[(Double, Array[Double])], - gradient: Gradient, - updater: Updater, - stepSize: Double, - numIterations: Int, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { - - val stochasticLossHistory = new ArrayBuffer[Double](numIterations) - - val nexamples: Long = data.count() - val miniBatchSize = nexamples * miniBatchFraction - - // Initialize weights as a column vector - var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) - var regVal = 0.0 - - for (i <- 1 to numIterations) { - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map { - case (y, features) => - val featuresCol = new DoubleMatrix(features.length, 1, features:_*) - val (grad, loss) = gradient.compute(featuresCol, y, weights) - (grad, loss) - }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) - - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 - } - - logInfo("GradientDescent finished. Last 10 stochastic losses %s".format( - stochasticLossHistory.takeRight(10).mkString(", "))) - - (weights.toArray, stochasticLossHistory.toArray) - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala deleted file mode 100644 index 76a519c338..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import spark.RDD - -trait Optimizer { - - /** - * Solve the provided convex optimization problem. - */ - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] - -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala deleted file mode 100644 index db67d6b0bc..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import scala.math._ -import org.jblas.DoubleMatrix - -/** - * Class used to update weights used in Gradient Descent. - */ -abstract class Updater extends Serializable { - /** - * Compute an updated value for weights given the gradient, stepSize, iteration number and - * regularization parameter. Also returns the regularization value computed using the - * *updated* weights. - * - * @param weightsOld - Column matrix of size nx1 where n is the number of features. - * @param gradient - Column matrix of size nx1 where n is the number of features. - * @param stepSize - step size across iterations - * @param iter - Iteration number - * @param regParam - Regularization parameter - * - * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, - * and the second element is the regularization value computed using updated weights. - */ - def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, - regParam: Double): (DoubleMatrix, Double) -} - -/** - * A simple updater that adaptively adjusts the learning rate the - * square root of the number of iterations. Does not perform any regularization. - */ -class SimpleUpdater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - (weightsOld.sub(normGradient), 0) - } -} - -/** - * Updater that adjusts learning rate and performs L1 regularization. - * - * The corresponding proximal operator used is the soft-thresholding function. - * That is, each weight component is shrunk towards 0 by shrinkageVal. - * - * If w > shrinkageVal, set weight component to w-shrinkageVal. - * If w < -shrinkageVal, set weight component to w+shrinkageVal. - * If -shrinkageVal < w < shrinkageVal, set weight component to 0. - * - * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) - */ -class L1Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - // Take gradient step - val newWeights = weightsOld.sub(normGradient) - // Soft thresholding - val shrinkageVal = regParam * thisIterStepSize - (0 until newWeights.length).foreach { i => - val wi = newWeights.get(i) - newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) - } - (newWeights, newWeights.norm1 * regParam) - } -} - -/** - * Updater that adjusts the learning rate and performs L2 regularization - */ -class SquaredL2Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0) - (newWeights, pow(newWeights.norm2, 2.0) * regParam) - } -} - diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala deleted file mode 100644 index dbfbf59975..0000000000 --- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation - -import scala.collection.mutable.{ArrayBuffer, BitSet} -import scala.util.Random -import scala.util.Sorting - -import spark.{HashPartitioner, Partitioner, SparkContext, RDD} -import spark.storage.StorageLevel -import spark.KryoRegistrator -import spark.SparkContext._ - -import com.esotericsoftware.kryo.Kryo -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - - -/** - * Out-link information for a user or product block. This includes the original user/product IDs - * of the elements within this block, and the list of destination blocks that each user or - * product will need to send its feature vector to. - */ -private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) - - -/** - * In-link information for a user (or product) block. This includes the original user/product IDs - * of the elements within this block, as well as an array of indices and ratings that specify - * which user in the block will be rated by which products from each product block (or vice-versa). - * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, - * indices and ratings, for the i'th product that will be sent to us by product block b (call this - * P). These arrays represent the users that product P had ratings for (by their index in this - * block), as well as the corresponding rating for each one. We can thus use this information when - * we get product block b's message to update the corresponding users. - */ -private[recommendation] case class InLinkBlock( - elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) - - -/** - * A more compact class to represent a rating than Tuple3[Int, Int, Double]. - */ -case class Rating(val user: Int, val product: Int, val rating: Double) - -/** - * Alternating Least Squares matrix factorization. - * - * This is a blocked implementation of the ALS factorization algorithm that groups the two sets - * of factors (referred to as "users" and "products") into blocks and reduces communication by only - * sending one copy of each user vector to each product block on each iteration, and only for the - * product blocks that need that user's feature vector. This is achieved by precomputing some - * information about the ratings matrix to determine the "out-links" of each user (which blocks of - * products it will contribute to) and "in-link" information for each product (which of the feature - * vectors it receives from each user block it will depend on). This allows us to send only an - * array of feature vectors between each user block and product block, and have the product block - * find the users' ratings and update the products based on these messages. - */ -class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double) - extends Serializable -{ - def this() = this(-1, 10, 10, 0.01) - - /** - * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured - * number of blocks. Default: -1. - */ - def setBlocks(numBlocks: Int): ALS = { - this.numBlocks = numBlocks - this - } - - /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { - this.rank = rank - this - } - - /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { - this.iterations = iterations - this - } - - /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { - this.lambda = lambda - this - } - - /** - * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. - * Returns a MatrixFactorizationModel with feature vectors for each user and product. - */ - def run(ratings: RDD[Rating]): MatrixFactorizationModel = { - val numBlocks = if (this.numBlocks == -1) { - math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2) - } else { - this.numBlocks - } - - val partitioner = new HashPartitioner(numBlocks) - - val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) } - val ratingsByProductBlock = ratings.map{ rating => - (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating)) - } - - val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) - val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) - - // Initialize user and product factors randomly, but use a deterministic seed for each partition - // so that fault recovery works - val seedGen = new Random() - val seed1 = seedGen.nextInt() - val seed2 = seedGen.nextInt() - // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable - def hash(x: Int): Int = { - val r = x ^ (x >>> 20) ^ (x >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } - var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(hash(seed1 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(hash(seed2 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - - for (iter <- 0 until iterations) { - // perform ALS update - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda) - } - - // Flatten and cache the two final RDDs to un-block them - val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - - usersOut.persist() - productsOut.persist() - - new MatrixFactorizationModel(rank, usersOut, productsOut) - } - - /** - * Make the out-links table for a block of the users (or products) dataset given the list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) - for (r <- ratings) { - shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true - } - OutLinkBlock(userIds, shouldSend) - } - - /** - * Make the in-links table for a block of the users (or products) dataset given a list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - // Split out our ratings by product block - val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) - for (r <- ratings) { - blockRatings(r.product % numBlocks) += r - } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) - for (productBlock <- 0 until numBlocks) { - // Create an array of (product, Seq(Rating)) ratings - val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray - // Sort them by product ID - val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { - def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = a._1 - b._1 - } - Sorting.quickSort(groupedRatings)(ordering) - // Translate the user IDs to indices based on userIdToPos - ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => - (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) - } - } - InLinkBlock(userIds, ratingsForBlock) - } - - /** - * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for - * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid - * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. - */ - private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)]) - : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = - { - val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) - val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map{_._2}.toArray - val inLinkBlock = makeInLinkBlock(numBlocks, ratings) - val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) - Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, true) - links.persist(StorageLevel.MEMORY_AND_DISK) - (links.mapValues(_._1), links.mapValues(_._2)) - } - - /** - * Make a random factor vector with the given random. - */ - private def randomFactor(rank: Int, rand: Random): Array[Double] = { - Array.fill(rank)(rand.nextDouble) - } - - /** - * Compute the user feature vectors given the current products (or vice-versa). This first joins - * the products with their out-links to generate a set of messages to each destination block - * (specifically, the features for the products that user block cares about), then groups these - * by destination and joins them with the in-link info to figure out how to update each user. - * It returns an RDD of new feature vectors for each user block. - */ - private def updateFeatures( - products: RDD[(Int, Array[Array[Double]])], - productOutLinks: RDD[(Int, OutLinkBlock)], - userInLinks: RDD[(Int, InLinkBlock)], - partitioner: Partitioner, - rank: Int, - lambda: Double) - : RDD[(Int, Array[Array[Double]])] = - { - val numBlocks = products.partitions.size - productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { - if (outLinkBlock.shouldSend(p)(userBlock)) { - toSend(userBlock) += factors(p) - } - } - toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(partitioner) - .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) } - } - - /** - * Compute the new feature vectors for a block of the users matrix given the list of factors - * it received from each product and its InLinkBlock. - */ - def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double) - : Array[Array[Double]] = - { - // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numBlocks = blockFactors.length - val numUsers = inLinkBlock.elementIds.length - - // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since - // the matrices are symmetric - val triangleSize = rank * (rank + 1) / 2 - val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) - val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) - - // Some temp variables to avoid memory allocation - val tempXtX = DoubleMatrix.zeros(triangleSize) - val fullXtX = DoubleMatrix.zeros(rank, rank) - - // Compute the XtX and Xy values for each user by adding products it rated in each product block - for (productBlock <- 0 until numBlocks) { - for (p <- 0 until blockFactors(productBlock).length) { - val x = new DoubleMatrix(blockFactors(productBlock)(p)) - fillXtX(x, tempXtX) - val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) - for (i <- 0 until us.length) { - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - } - } - } - - // Solve the least-squares problem for each user and return the new feature vectors - userXtX.zipWithIndex.map{ case (triangularXtX, index) => - // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(triangularXtX, fullXtX) - // Add regularization - (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) - // Solve the resulting matrix, which is symmetric and positive-definite - Solve.solvePositive(fullXtX, userXy(index)).data - } - } - - /** - * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing - * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values - * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order. - */ - private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) { - var i = 0 - var pos = 0 - while (i < x.length) { - var j = 0 - while (j <= i) { - xtxDest.data(pos) = x.data(i) * x.data(j) - pos += 1 - j += 1 - } - i += 1 - } - } - - /** - * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square - * matrix that it represents, storing it into destMatrix. - */ - private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { - val rank = destMatrix.rows - var i = 0 - var pos = 0 - while (i < rank) { - var j = 0 - while (j <= i) { - destMatrix.data(i*rank + j) = triangularMatrix.data(pos) - destMatrix.data(j*rank + i) = triangularMatrix.data(pos) - pos += 1 - j += 1 - } - i += 1 - } - } -} - - -/** - * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton. - */ -object ALS { - /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. This is done using a level of - * parallelism given by `blocks`. - * - * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) - * @param blocks level of parallelism to split computation into - */ - def train( - ratings: RDD[Rating], - rank: Int, - iterations: Int, - lambda: Double, - blocks: Int) - : MatrixFactorizationModel = - { - new ALS(blocks, rank, iterations, lambda).run(ratings) - } - - /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. - * - * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) - */ - def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) - : MatrixFactorizationModel = - { - train(ratings, rank, iterations, lambda, -1) - } - - /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. - * - * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - */ - def train(ratings: RDD[Rating], rank: Int, iterations: Int) - : MatrixFactorizationModel = - { - train(ratings, rank, iterations, 0.01, -1) - } - - private class ALSRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[Rating]) - } - } - - def main(args: Array[String]) { - if (args.length != 5 && args.length != 6) { - println("Usage: ALS []") - System.exit(1) - } - val (master, ratingsFile, rank, iters, outputDir) = - (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) - val blocks = if (args.length == 6) args(5).toInt else -1 - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) - System.setProperty("spark.kryo.referenceTracking", "false") - System.setProperty("spark.kryoserializer.buffer.mb", "8") - System.setProperty("spark.locality.wait", "10000") - val sc = new SparkContext(master, "ALS") - val ratings = sc.textFile(ratingsFile).map { line => - val fields = line.split(',') - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) - } - val model = ALS.train(ratings, rank, iters, 0.01, blocks) - model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } - .saveAsTextFile(outputDir + "/userFeatures") - model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } - .saveAsTextFile(outputDir + "/productFeatures") - println("Final user/product features written to " + outputDir) - System.exit(0) - } -} diff --git a/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala deleted file mode 100644 index 5e21717da5..0000000000 --- a/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation - -import spark.RDD -import spark.SparkContext._ - -import org.jblas._ - -/** - * Model representing the result of matrix factorization. - * - * @param rank Rank for the features in this model. - * @param userFeatures RDD of tuples where each tuple represents the userId and - * the features computed for this user. - * @param productFeatures RDD of tuples where each tuple represents the productId - * and the features computed for this product. - */ -class MatrixFactorizationModel( - val rank: Int, - val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) - extends Serializable -{ - /** Predict the rating of one user for one product. */ - def predict(user: Int, product: Int): Double = { - val userVector = new DoubleMatrix(userFeatures.lookup(user).head) - val productVector = new DoubleMatrix(productFeatures.lookup(product).head) - userVector.dot(productVector) - } - - // TODO: Figure out what good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. -} diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala deleted file mode 100644 index d164d415d6..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkException} -import spark.mllib.optimization._ - -import org.jblas.DoubleMatrix - -/** - * GeneralizedLinearModel (GLM) represents a model trained using - * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and - * an intercept. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) - extends Serializable { - - // Create a column vector that can be used for predictions - private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) - - /** - * Predict the result given a data point and the weights learned. - * - * @param dataMatrix Row vector containing the features for this data point - * @param weightMatrix Column vector containing the weights of the model - * @param intercept Intercept of the model. - */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double): Double - - /** - * Predict values for the given data set using the model trained. - * - * @param testData RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(testData: spark.RDD[Array[Double]]): RDD[Double] = { - // A small optimization to avoid serializing the entire model. Only the weightsMatrix - // and intercept is needed. - val localWeights = weightsMatrix - val localIntercept = intercept - - testData.map { x => - val dataMatrix = new DoubleMatrix(1, x.length, x:_*) - predictPoint(dataMatrix, localWeights, localIntercept) - } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param testData array representing a single data point - * @return Double prediction from the trained model - */ - def predict(testData: Array[Double]): Double = { - val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - predictPoint(dataMat, weightsMatrix, intercept) - } -} - -/** - * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM). - * This class should be extended with an Optimizer to create a new GLM. - */ -abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] - extends Logging with Serializable { - - protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() - - val optimizer: Optimizer - - protected var addIntercept: Boolean = true - - protected var validateData: Boolean = true - - /** - * Create a model given the weights and intercept - */ - protected def createModel(weights: Array[Double], intercept: Double): M - - /** - * Set if the algorithm should add an intercept. Default true. - */ - def setIntercept(addIntercept: Boolean): this.type = { - this.addIntercept = addIntercept - this - } - - /** - * Set if the algorithm should validate data before training. Default true. - */ - def setValidateData(validateData: Boolean): this.type = { - this.validateData = validateData - this - } - - /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. - */ - def run(input: RDD[LabeledPoint]) : M = { - val nfeatures: Int = input.first().features.length - val initialWeights = Array.fill(nfeatures)(1.0) - run(input, initialWeights) - } - - /** - * Run the algorithm with the configured parameters on an input RDD - * of LabeledPoint entries starting from the initial weights provided. - */ - def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { - - // Check the data properties before running the optimizer - if (validateData && !validators.forall(func => func(input))) { - throw new SparkException("Input validation failed.") - } - - // Add a extra variable consisting of all 1.0's for the intercept. - val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) - } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) - } - - val initialWeightsWithIntercept = if (addIntercept) { - Array(1.0, initialWeights:_*) - } else { - initialWeights - } - - val weights = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = weights(0) - val weightsScaled = weights.tail - - val model = createModel(weightsScaled, intercept) - - logInfo("Final model weights " + model.weights.mkString(",")) - logInfo("Final model intercept " + model.intercept) - model - } -} diff --git a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala deleted file mode 100644 index 3de60482c5..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -/** - * Class that represents the features and labels of a data point. - * - * @param label Label for this data point. - * @param features List of features for this data point. - */ -case class LabeledPoint(val label: Double, val features: Array[Double]) diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala deleted file mode 100644 index 0f33456ef4..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - -/** - * Regression model trained using Lasso. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class LassoModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - dataMatrix.dot(weightMatrix) + intercept - } -} - -/** - * Train a regression model with L1-regularization using Stochastic Gradient Descent. - */ -class LassoWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LassoModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new L1Updater() - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - // We don't want to penalize the intercept, so set this to false. - setIntercept(false) - - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - - /** - * Construct a Lasso object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) - - new LassoModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : LassoModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) - } -} - -/** - * Top-level methods for calling Lasso. - */ -object LassoWithSGD { - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LassoModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double) - : LassoModel = - { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LassoModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : LassoModel = - { - train(input, numIterations, 1.0, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: Lasso ") - System.exit(1) - } - val sc = new SparkContext(args(0), "Lasso") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala deleted file mode 100644 index 885ff5a30d..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - -/** - * Regression model trained using LinearRegression. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - dataMatrix.dot(weightMatrix) + intercept - } -} - -/** - * Train a regression model with no regularization using Stochastic Gradient Descent. - */ -class LinearRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new SimpleUpdater() - val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setMiniBatchFraction(miniBatchFraction) - - /** - * Construct a LinearRegression object with default parameters - */ - def this() = this(1.0, 100, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new LinearRegressionModel(weights, intercept) - } -} - -/** - * Top-level methods for calling LinearRegression. - */ -object LinearRegressionWithSGD { - - /** - * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LinearRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double) - : LinearRegressionModel = - { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LinearRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : LinearRegressionModel = - { - train(input, numIterations, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: LinearRegression ") - System.exit(1) - } - val sc = new SparkContext(args(0), "LinearRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala deleted file mode 100644 index b845ba1a89..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.RDD - -trait RegressionModel extends Serializable { - /** - * Predict values for the given data set using the model trained. - * - * @param testData RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(testData: RDD[Array[Double]]): RDD[Double] - - /** - * Predict values for a single data point using the model trained. - * - * @param testData array representing a single data point - * @return Double prediction from the trained model - */ - def predict(testData: Array[Double]): Double -} diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala deleted file mode 100644 index cb1303dd99..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - -/** - * Regression model trained using RidgeRegression. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class RidgeRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - dataMatrix.dot(weightMatrix) + intercept - } -} - -/** - * Train a regression model with L2-regularization using Stochastic Gradient Descent. - */ -class RidgeRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[RidgeRegressionModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new SquaredL2Updater() - - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) - - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - - /** - * Construct a RidgeRegression object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) - - new RidgeRegressionModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) - } -} - -/** - * Top-level methods for calling RidgeRegression. - */ -object RidgeRegressionWithSGD { - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( - input, initialWeights) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : RidgeRegressionModel = - { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a RidgeRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double) - : RidgeRegressionModel = - { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a RidgeRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : RidgeRegressionModel = - { - train(input, numIterations, 1.0, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: RidgeRegression " + - " ") - System.exit(1) - } - val sc = new SparkContext(args(0), "RidgeRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, - args(3).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala deleted file mode 100644 index 57553accf1..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.util - -import spark.{RDD, Logging} -import spark.mllib.regression.LabeledPoint - -/** - * A collection of methods used to validate data before applying ML algorithms. - */ -object DataValidators extends Logging { - - /** - * Function to check if labels used for classification are either zero or one. - * - * @param data - input data set that needs to be checked - * - * @return True if labels are all zero or one, false otherwise. - */ - val classificationLabels: RDD[LabeledPoint] => Boolean = { data => - val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() - if (numInvalid != 0) { - logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") - } - numInvalid == 0 - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala deleted file mode 100644 index 672b63f65a..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.util - -import scala.util.Random - -import spark.{RDD, SparkContext} - -/** - * Generate test data for KMeans. This class first chooses k cluster centers - * from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian - * cluster with scale 1 around each center. - */ - -object KMeansDataGenerator { - - /** - * Generate an RDD containing test data for KMeans. - * - * @param sc SparkContext to use for creating the RDD - * @param numPoints Number of points that will be contained in the RDD - * @param k Number of clusters - * @param d Number of dimensions - * @param r Scaling factor for the distribution of the initial centers - * @param numPartitions Number of partitions of the generated RDD; default 2 - */ - def generateKMeansRDD( - sc: SparkContext, - numPoints: Int, - k: Int, - d: Int, - r: Double, - numPartitions: Int = 2) - : RDD[Array[Double]] = - { - // First, generate some centers - val rand = new Random(42) - val centers = Array.fill(k)(Array.fill(d)(rand.nextGaussian() * r)) - // Then generate points around each center - sc.parallelize(0 until numPoints, numPartitions).map { idx => - val center = centers(idx % k) - val rand2 = new Random(42 + idx) - Array.tabulate(d)(i => center(i) + rand2.nextGaussian()) - } - } - - def main(args: Array[String]) { - if (args.length < 6) { - println("Usage: KMeansGenerator " + - " []") - System.exit(1) - } - - val sparkMaster = args(0) - val outputPath = args(1) - val numPoints = args(2).toInt - val k = args(3).toInt - val d = args(4).toInt - val r = args(5).toDouble - val parts = if (args.length >= 7) args(6).toInt else 2 - - val sc = new SparkContext(sparkMaster, "KMeansDataGenerator") - val data = generateKMeansRDD(sc, numPoints, k, d, r, parts) - data.map(_.mkString(" ")).saveAsTextFile(outputPath) - - System.exit(0) - } -} - diff --git a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala deleted file mode 100644 index 9f48477f84..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.util - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.jblas.DoubleMatrix - -import spark.{RDD, SparkContext} -import spark.mllib.regression.LabeledPoint -import spark.mllib.regression.LabeledPoint - -/** - * Generate sample data used for Linear Data. This class generates - * uniformly random values for every feature and adds Gaussian noise with mean `eps` to the - * response variable `Y`. - */ -object LinearDataGenerator { - - /** - * Return a Java List of synthetic data randomly generated according to a multi - * collinear model. - * @param intercept Data intercept - * @param weights Weights to be applied. - * @param nPoints Number of points in sample. - * @param seed Random seed - * @return Java List of input. - */ - def generateLinearInputAsList( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int, - eps: Double): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) - } - - /** - * - * @param intercept Data intercept - * @param weights Weights to be applied. - * @param nPoints Number of points in sample. - * @param seed Random seed - * @param eps Epsilon scaling factor. - * @return - */ - def generateLinearInput( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int, - eps: Double = 0.1): Seq[LabeledPoint] = { - - val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) - val y = x.map { xi => - (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian() - } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) - } - - /** - * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso, - * and uregularized variants. - * - * @param sc SparkContext to be used for generating the RDD. - * @param nexamples Number of examples that will be contained in the RDD. - * @param nfeatures Number of features to generate for each example. - * @param eps Epsilon factor by which examples are scaled. - * @param weights Weights associated with the first weights.length features. - * @param nparts Number of partitions in the RDD. Default value is 2. - * - * @return RDD of LabeledPoint containing sample data. - */ - def generateLinearRDD( - sc: SparkContext, - nexamples: Int, - nfeatures: Int, - eps: Double, - nparts: Int = 2, - intercept: Double = 0.0) : RDD[LabeledPoint] = { - org.jblas.util.Random.seed(42) - // Random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) - - val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p => - val seed = 42 + p - val examplesInPartition = nexamples / nparts - generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps) - } - data - } - - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: LinearDataGenerator " + - " [num_examples] [num_features] [num_partitions]") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 - val nfeatures: Int = if (args.length > 3) args(3).toInt else 100 - val parts: Int = if (args.length > 4) args(4).toInt else 2 - val eps = 10 - - val sc = new SparkContext(sparkMaster, "LinearDataGenerator") - val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts) - - MLUtils.saveLabeledData(data, outputPath) - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala deleted file mode 100644 index d6402f23e2..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.util - -import scala.util.Random - -import spark.{RDD, SparkContext} -import spark.mllib.regression.LabeledPoint - -/** - * Generate test data for LogisticRegression. This class chooses positive labels - * with probability `probOne` and scales features for positive examples by `eps`. - */ - -object LogisticRegressionDataGenerator { - - /** - * Generate an RDD containing test data for LogisticRegression. - * - * @param sc SparkContext to use for creating the RDD. - * @param nexamples Number of examples that will be contained in the RDD. - * @param nfeatures Number of features to generate for each example. - * @param eps Epsilon factor by which positive examples are scaled. - * @param nparts Number of partitions of the generated RDD. Default value is 2. - * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. - */ - def generateLogisticRDD( - sc: SparkContext, - nexamples: Int, - nfeatures: Int, - eps: Double, - nparts: Int = 2, - probOne: Double = 0.5): RDD[LabeledPoint] = { - val data = sc.parallelize(0 until nexamples, nparts).map { idx => - val rnd = new Random(42 + idx) - - val y = if (idx % 2 == 0) 0.0 else 1.0 - val x = Array.fill[Double](nfeatures) { - rnd.nextGaussian() + (y * eps) - } - LabeledPoint(y, x) - } - data - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: LogisticRegressionGenerator " + - " ") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 - val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 - val parts: Int = if (args.length > 4) args(4).toInt else 2 - val eps = 3 - - val sc = new SparkContext(sparkMaster, "LogisticRegressionDataGenerator") - val data = generateLogisticRDD(sc, nexamples, nfeatures, eps, parts) - - MLUtils.saveLabeledData(data, outputPath) - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala deleted file mode 100644 index 88992cde0c..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation - -import scala.util.Random - -import org.jblas.DoubleMatrix - -import spark.{RDD, SparkContext} -import spark.mllib.util.MLUtils - -/** -* Generate RDD(s) containing data for Matrix Factorization. -* -* This method samples training entries according to the oversampling factor -* 'trainSampFact', which is a multiplicative factor of the number of -* degrees of freedom of the matrix: rank*(m+n-rank). -* -* It optionally samples entries for a testing matrix using -* 'testSampFact', the percentage of the number of training entries -* to use for testing. -* -* This method takes the following inputs: -* sparkMaster (String) The master URL. -* outputPath (String) Directory to save output. -* m (Int) Number of rows in data matrix. -* n (Int) Number of columns in data matrix. -* rank (Int) Underlying rank of data matrix. -* trainSampFact (Double) Oversampling factor. -* noise (Boolean) Whether to add gaussian noise to training data. -* sigma (Double) Standard deviation of added gaussian noise. -* test (Boolean) Whether to create testing RDD. -* testSampFact (Double) Percentage of training data to use as test data. -*/ - -object MFDataGenerator{ - - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: MFDataGenerator " + - " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val m: Int = if (args.length > 2) args(2).toInt else 100 - val n: Int = if (args.length > 3) args(3).toInt else 100 - val rank: Int = if (args.length > 4) args(4).toInt else 10 - val trainSampFact: Double = if (args.length > 5) args(5).toDouble else 1.0 - val noise: Boolean = if (args.length > 6) args(6).toBoolean else false - val sigma: Double = if (args.length > 7) args(7).toDouble else 0.1 - val test: Boolean = if (args.length > 8) args(8).toBoolean else false - val testSampFact: Double = if (args.length > 9) args(9).toDouble else 0.1 - - val sc = new SparkContext(sparkMaster, "MFDataGenerator") - - val A = DoubleMatrix.randn(m, rank) - val B = DoubleMatrix.randn(rank, n) - val z = 1 / (scala.math.sqrt(scala.math.sqrt(rank))) - A.mmuli(z) - B.mmuli(z) - val fullData = A.mmul(B) - - val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt - val rand = new Random() - val mn = m * n - val shuffled = rand.shuffle(1 to mn toIterable) - - val omega = shuffled.slice(0, sampSize) - val ordered = omega.sortWith(_ < _).toArray - val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) - - // optionally add gaussian noise - if (noise) { - trainData.map(x => (x._1, x._2, x._3 + rand.nextGaussian * sigma)) - } - - trainData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) - - // optionally generate testing data - if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt - val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) - val testOrdered = testOmega.sortWith(_ < _).toArray - val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) - testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) - } - - sc.stop() - - } -} \ No newline at end of file diff --git a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala deleted file mode 100644 index a8e6ae9953..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.util - -import spark.{RDD, SparkContext} -import spark.SparkContext._ - -import org.jblas.DoubleMatrix -import spark.mllib.regression.LabeledPoint - -/** - * Helper methods to load, save and pre-process data used in ML Lib. - */ -object MLUtils { - - /** - * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.split(',') - val label = parts(0).toDouble - val features = parts(1).trim().split(' ').map(_.toDouble) - LabeledPoint(label, features) - } - } - - /** - * Save labeled data to a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param data An RDD of LabeledPoints containing data to be saved. - * @param dir Directory to save the data. - */ - def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) - dataStr.saveAsTextFile(dir) - } - - /** - * Utility function to compute mean and standard deviation on a given dataset. - * - * @param data - input data set whose statistics are computed - * @param nfeatures - number of features - * @param nexamples - number of examples in input dataset - * - * @return (yMean, xColMean, xColSd) - Tuple consisting of - * yMean - mean of the labels - * xColMean - Row vector with mean for every column (or feature) of the input data - * xColSd - Row vector standard deviation for every column (or feature) of the input data. - */ - def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): - (Double, DoubleMatrix, DoubleMatrix) = { - val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples - - // NOTE: We shuffle X by column here to compute column sum and sum of squares. - val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => - val nCols = labeledPoint.features.length - // Traverse over every column and emit (col, value, value^2) - Iterator.tabulate(nCols) { i => - (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) - } - }.reduceByKey { case(x1, x2) => - (x1._1 + x2._1, x1._2 + x2._2) - } - val xColSumsMap = xColSumSq.collectAsMap() - - val xColMean = DoubleMatrix.zeros(nfeatures, 1) - val xColSd = DoubleMatrix.zeros(nfeatures, 1) - - // Compute mean and unbiased variance using column sums - var col = 0 - while (col < nfeatures) { - xColMean.put(col, xColSumsMap(col)._1 / nexamples) - val variance = - (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / (nexamples) - xColSd.put(col, math.sqrt(variance)) - col += 1 - } - - (yMean, xColMean, xColSd) - } - - /** - * Return the squared Euclidean distance between two vectors. - */ - def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = { - if (v1.length != v2.length) { - throw new IllegalArgumentException("Vector sizes don't match") - } - var i = 0 - var sum = 0.0 - while (i < v1.length) { - sum += (v1(i) - v2(i)) * (v1(i) - v2(i)) - i += 1 - } - sum - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala deleted file mode 100644 index eff456cad6..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala +++ /dev/null @@ -1,50 +0,0 @@ -package spark.mllib.util - -import scala.util.Random - -import spark.{RDD, SparkContext} - -import org.jblas.DoubleMatrix -import spark.mllib.regression.LabeledPoint - -/** - * Generate sample data used for SVM. This class generates uniform random values - * for the features and adds Gaussian noise with weight 0.1 to generate labels. - */ -object SVMDataGenerator { - - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: SVMGenerator " + - " [num_examples] [num_features] [num_partitions]") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 - val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 - val parts: Int = if (args.length > 4) args(4).toInt else 2 - - val sc = new SparkContext(sparkMaster, "SVMGenerator") - - val globalRnd = new Random(94720) - val trueWeights = new DoubleMatrix(1, nfeatures + 1, - Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) - - val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => - val rnd = new Random(42 + idx) - - val x = Array.fill[Double](nfeatures) { - rnd.nextDouble() * 2.0 - 1.0 - } - val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1 - val y = if (yD < 0) 0.0 else 1.0 - LabeledPoint(y, x) - } - - MLUtils.saveLabeledData(data, outputPath) - - sc.stop() - } -} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000..e18e3bc6a8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.regression.LabeledPoint; + +public class JavaLogisticRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LogisticRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLRUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double B = -1.5; + + JavaRDD testRDD = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + List validationData = + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + + LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); + lrImpl.optimizer().setStepSize(1.0) + .setRegParam(1.0) + .setNumIterations(100); + LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLRUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double B = -1.5; + + JavaRDD testRDD = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + List validationData = + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + + LogisticRegressionModel model = LogisticRegressionWithSGD.train( + testRDD.rdd(), 100, 1.0, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java new file mode 100644 index 0000000000..117e5eaa8b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification; + + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.regression.LabeledPoint; + +public class JavaSVMSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaSVMSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, SVMModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runSVMUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0}; + + JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); + List validationData = + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + + SVMWithSGD svmSGDImpl = new SVMWithSGD(); + svmSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(1.0) + .setNumIterations(100); + SVMModel model = svmSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runSVMUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0}; + + JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); + List validationData = + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + + SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java new file mode 100644 index 0000000000..32d3934ac1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaKMeans"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + // L1 distance between two points + double distance1(double[] v1, double[] v2) { + double distance = 0.0; + for (int i = 0; i < v1.length; ++i) { + distance = Math.max(distance, Math.abs(v1[i] - v2[i])); + } + return distance; + } + + // Assert that two sets of points are equal, within EPSILON tolerance + void assertSetsEqual(double[][] v1, double[][] v2) { + double EPSILON = 1e-4; + Assert.assertTrue(v1.length == v2.length); + for (int i = 0; i < v1.length; ++i) { + double minDistance = Double.MAX_VALUE; + for (int j = 0; j < v2.length; ++j) { + minDistance = Math.min(minDistance, distance1(v1[i], v2[j])); + } + Assert.assertTrue(minDistance <= EPSILON); + } + + for (int i = 0; i < v2.length; ++i) { + double minDistance = Double.MAX_VALUE; + for (int j = 0; j < v1.length; ++j) { + minDistance = Math.min(minDistance, distance1(v2[i], v1[j])); + } + Assert.assertTrue(minDistance <= EPSILON); + } + } + + + @Test + public void runKMeansUsingStaticMethods() { + List points = new ArrayList(); + points.add(new double[]{1.0, 2.0, 6.0}); + points.add(new double[]{1.0, 3.0, 0.0}); + points.add(new double[]{1.0, 4.0, 6.0}); + + double[][] expectedCenter = { {1.0, 3.0, 4.0} }; + + JavaRDD data = sc.parallelize(points, 2); + KMeansModel model = KMeans.train(data.rdd(), 1, 1); + assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + } + + @Test + public void runKMeansUsingConstructor() { + List points = new ArrayList(); + points.add(new double[]{1.0, 2.0, 6.0}); + points.add(new double[]{1.0, 3.0, 0.0}); + points.add(new double[]{1.0, 4.0, 6.0}); + + double[][] expectedCenter = { {1.0, 3.0, 4.0} }; + + JavaRDD data = sc.parallelize(points, 2); + KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = new KMeans().setK(1) + .setMaxIterations(1) + .setRuns(1) + .setInitializationMode(KMeans.RANDOM()) + .run(data.rdd()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java new file mode 100644 index 0000000000..3323f6cee2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.jblas.DoubleMatrix; + +public class JavaALSSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaALS"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, + DoubleMatrix trueRatings, double matchThreshold) { + DoubleMatrix predictedU = new DoubleMatrix(users, features); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); + for (int i = 0; i < features; ++i) { + for (scala.Tuple2 userFeature : userFeatures) { + predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); + } + } + DoubleMatrix predictedP = new DoubleMatrix(products, features); + + List> productFeatures = + model.productFeatures().toJavaRDD().collect(); + for (int i = 0; i < features; ++i) { + for (scala.Tuple2 productFeature : productFeatures) { + predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + } + } + + DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); + + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double correct = trueRatings.get(u, p); + Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); + } + } + } + + @Test + public void runALSUsingStaticMethods() { + int features = 1; + int iterations = 15; + int users = 10; + int products = 10; + scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7); + + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.3); + } + + @Test + public void runALSUsingConstructor() { + int features = 2; + int iterations = 15; + int users = 20; + int products = 30; + scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7); + + JavaRDD data = sc.parallelize(testData._1()); + + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .run(data.rdd()); + validatePrediction(model, users, products, features, testData._2(), 0.3); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java new file mode 100644 index 0000000000..f44b25cd44 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaLassoSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLassoSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LassoModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLassoUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LassoWithSGD lassoSGDImpl = new LassoWithSGD(); + lassoSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.01) + .setNumIterations(20); + LassoModel model = lassoSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLassoUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java new file mode 100644 index 0000000000..5a4410a632 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaLinearRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LinearRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLinearRegressionUsingConstructor() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLinearRegressionUsingStaticMethods() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java new file mode 100644 index 0000000000..2fdd5fc8fd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.jblas.DoubleMatrix; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaRidgeRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + double predictionError(List validationData, RidgeRegressionModel model) { + double errorSum = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + errorSum += (prediction - point.label()) * (prediction - point.label()); + } + return errorSum / validationData.size(); + } + + List generateRidgeData(int numPoints, int nfeatures, double eps) { + org.jblas.util.Random.seed(42); + // Pick weights as random values distributed uniformly in [-0.5, 0.5] + DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); + // Set first two weights to eps + w.put(0, 0, eps); + w.put(1, 0, eps); + return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); + } + + @Test + public void runRidgeRegressionUsingConstructor() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); + List validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); + ridgeSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.0) + .setNumIterations(200); + RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); + double unRegularizedErr = predictionError(validationData, model); + + ridgeSGDImpl.optimizer().setRegParam(0.1); + model = ridgeSGDImpl.run(testRDD.rdd()); + double regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } + + @Test + public void runRidgeRegressionUsingStaticMethods() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); + List validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); + double unRegularizedErr = predictionError(validationData, model); + + model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); + double regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } +} diff --git a/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java deleted file mode 100644 index e0ebd45cd8..0000000000 --- a/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import spark.mllib.regression.LabeledPoint; - -public class JavaLogisticRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LogisticRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLRUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double B = -1.5; - - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - - LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); - lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); - LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLRUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double B = -1.5; - - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java deleted file mode 100644 index 7881b3c38f..0000000000 --- a/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification; - - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import spark.mllib.regression.LabeledPoint; - -public class JavaSVMSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaSVMSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, SVMModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runSVMUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0}; - - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); - List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); - - SVMWithSGD svmSGDImpl = new SVMWithSGD(); - svmSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); - SVMModel model = svmSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runSVMUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0}; - - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); - List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); - - SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java deleted file mode 100644 index 3f2d82bfb4..0000000000 --- a/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -public class JavaKMeansSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeans"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - // L1 distance between two points - double distance1(double[] v1, double[] v2) { - double distance = 0.0; - for (int i = 0; i < v1.length; ++i) { - distance = Math.max(distance, Math.abs(v1[i] - v2[i])); - } - return distance; - } - - // Assert that two sets of points are equal, within EPSILON tolerance - void assertSetsEqual(double[][] v1, double[][] v2) { - double EPSILON = 1e-4; - Assert.assertTrue(v1.length == v2.length); - for (int i = 0; i < v1.length; ++i) { - double minDistance = Double.MAX_VALUE; - for (int j = 0; j < v2.length; ++j) { - minDistance = Math.min(minDistance, distance1(v1[i], v2[j])); - } - Assert.assertTrue(minDistance <= EPSILON); - } - - for (int i = 0; i < v2.length; ++i) { - double minDistance = Double.MAX_VALUE; - for (int j = 0; j < v1.length; ++j) { - minDistance = Math.min(minDistance, distance1(v2[i], v1[j])); - } - Assert.assertTrue(minDistance <= EPSILON); - } - } - - - @Test - public void runKMeansUsingStaticMethods() { - List points = new ArrayList(); - points.add(new double[]{1.0, 2.0, 6.0}); - points.add(new double[]{1.0, 3.0, 0.0}); - points.add(new double[]{1.0, 4.0, 6.0}); - - double[][] expectedCenter = { {1.0, 3.0, 4.0} }; - - JavaRDD data = sc.parallelize(points, 2); - KMeansModel model = KMeans.train(data.rdd(), 1, 1); - assertSetsEqual(model.clusterCenters(), expectedCenter); - - model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - } - - @Test - public void runKMeansUsingConstructor() { - List points = new ArrayList(); - points.add(new double[]{1.0, 2.0, 6.0}); - points.add(new double[]{1.0, 3.0, 0.0}); - points.add(new double[]{1.0, 4.0, 6.0}); - - double[][] expectedCenter = { {1.0, 3.0, 4.0} }; - - JavaRDD data = sc.parallelize(points, 2); - KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - - model = new KMeans().setK(1) - .setMaxIterations(1) - .setRuns(1) - .setInitializationMode(KMeans.RANDOM()) - .run(data.rdd()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - } -} diff --git a/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java deleted file mode 100644 index 7993629a6d..0000000000 --- a/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation; - -import java.io.Serializable; -import java.util.List; - -import scala.Tuple2; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import org.jblas.DoubleMatrix; - -public class JavaALSSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaALS"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { - predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); - } - } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); - } - } - } - - @Test - public void runALSUsingStaticMethods() { - int features = 1; - int iterations = 15; - int users = 10; - int products = 10; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); - - JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3); - } - - @Test - public void runALSUsingConstructor() { - int features = 2; - int iterations = 15; - int users = 20; - int products = 30; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); - - JavaRDD data = sc.parallelize(testData._1()); - - MatrixFactorizationModel model = new ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3); - } -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java deleted file mode 100644 index 5863140baf..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaLassoSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLassoSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LassoModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLassoUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LassoWithSGD lassoSGDImpl = new LassoWithSGD(); - lassoSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); - LassoModel model = lassoSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLassoUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java deleted file mode 100644 index 50716c7861..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaLinearRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LinearRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLinearRegressionUsingConstructor() { - int nPoints = 100; - double A = 3.0; - double[] weights = {10, 10}; - - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); - LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLinearRegressionUsingStaticMethods() { - int nPoints = 100; - double A = 3.0; - double[] weights = {10, 10}; - - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java deleted file mode 100644 index 2c0aabad30..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.jblas.DoubleMatrix; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaRidgeRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - double predictionError(List validationData, RidgeRegressionModel model) { - double errorSum = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - errorSum += (prediction - point.label()) * (prediction - point.label()); - } - return errorSum / validationData.size(); - } - - List generateRidgeData(int numPoints, int nfeatures, double eps) { - org.jblas.util.Random.seed(42); - // Pick weights as random values distributed uniformly in [-0.5, 0.5] - DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); - // Set first two weights to eps - w.put(0, 0, eps); - w.put(1, 0, eps); - return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); - } - - @Test - public void runRidgeRegressionUsingConstructor() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); - - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); - - RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.0) - .setNumIterations(200); - RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); - double unRegularizedErr = predictionError(validationData, model); - - ridgeSGDImpl.optimizer().setRegParam(0.1); - model = ridgeSGDImpl.run(testRDD.rdd()); - double regularizedErr = predictionError(validationData, model); - - Assert.assertTrue(regularizedErr < unRegularizedErr); - } - - @Test - public void runRidgeRegressionUsingStaticMethods() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); - - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); - - RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); - double unRegularizedErr = predictionError(validationData, model); - - model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); - double regularizedErr = predictionError(validationData, model); - - Assert.assertTrue(regularizedErr < unRegularizedErr); - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000..34c67294e9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.util.Random +import scala.collection.JavaConversions._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression._ + +object LogisticRegressionSuite { + + def generateLogisticInputAsList( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + } + + // Generate input of the form Y = logistic(offset + scale*X) + def generateLogisticInput( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) + + // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) + val unifRand = new scala.util.Random(45) + val rLogis = (0 until nPoints).map { i => + val u = unifRand.nextDouble() + math.log(u) - math.log(1.0-u) + } + + // y <- A + B*x + rLogis() + // y <- as.numeric(y > 0) + val y: Seq[Int] = (0 until nPoints).map { i => + val yVal = offset + scale * x1(i) + rLogis(i) + if (yVal > 0) 1 else 0 + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) + testData + } + +} + +class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + (prediction != expected.label) + }.size + // At least 83% of the predictions should be on. + ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 + } + + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithSGD() + lr.optimizer.setStepSize(10.0).setNumIterations(20) + + val model = lr.run(testRDD) + + // Test the weights + val weight0 = model.weights(0) + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Array(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithSGD() + lr.optimizer.setStepSize(10.0).setNumIterations(10) + + val model = lr.run(testRDD, initialWeights) + + val weight0 = model.weights(0) + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala new file mode 100644 index 0000000000..6a957e3ddc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.util.Random +import scala.math.signum +import scala.collection.JavaConversions._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.jblas.DoubleMatrix + +import org.apache.spark.{SparkException, SparkContext} +import org.apache.spark.mllib.regression._ + +object SVMSuite { + + def generateSVMInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + } + + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + def generateSVMInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) + val y = x.map { xi => + val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + + intercept + 0.01 * rnd.nextGaussian() + if (yD < 0) 0.0 else 1.0 + } + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + +} + +class SVMSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + (prediction != expected.label) + }.size + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + + test("SVM using local random SGD") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val svm = new SVMWithSGD() + svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) + + val model = svm.run(testRDD) + + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("SVM local random SGD with initial weights") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + + val initialB = -1.0 + val initialC = -1.0 + val initialWeights = Array(initialB,initialC) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val svm = new SVMWithSGD() + svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) + + val model = svm.run(testRDD, initialWeights) + + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData,2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("SVM with invalid labels") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + + val testRDDInvalid = testRDD.map { lp => + if (lp.label == 0.0) { + LabeledPoint(-1.0, lp.features) + } else { + lp + } + } + + intercept[SparkException] { + val model = SVMWithSGD.train(testRDDInvalid, 100) + } + + // Turning off data validation should not throw an exception + val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala new file mode 100644 index 0000000000..94245f6027 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +import org.jblas._ + +class KMeansSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + val EPSILON = 1e-4 + + import KMeans.{RANDOM, K_MEANS_PARALLEL} + + def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")") + + def prettyPrint(points: Array[Array[Double]]): String = { + points.map(prettyPrint).mkString("(", "; ", ")") + } + + // L1 distance between two points + def distance1(v1: Array[Double], v2: Array[Double]): Double = { + v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max + } + + // Assert that two vectors are equal within tolerance EPSILON + def assertEqual(v1: Array[Double], v2: Array[Double]) { + def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2) + assert(v1.length == v2.length, errorMessage) + assert(distance1(v1, v2) <= EPSILON, errorMessage) + } + + // Assert that two sets of points are equal, within EPSILON tolerance + def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) { + def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2) + assert(set1.length == set2.length, errorMessage) + for (v <- set1) { + val closestDistance = set2.map(w => distance1(v, w)).min + if (closestDistance > EPSILON) { + fail(errorMessage) + } + } + for (v <- set2) { + val closestDistance = set1.map(w => distance1(v, w)).min + if (closestDistance > EPSILON) { + fail(errorMessage) + } + } + } + + test("single cluster") { + val data = sc.parallelize(Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0) + )) + + // No matter how many runs or iterations we use, we should get one cluster, + // centered at the mean of the points + + var model = KMeans.train(data, k=1, maxIterations=1) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=2) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train( + data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + } + + test("single cluster with big dataset") { + val smallData = Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0) + ) + val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) + + // No matter how many runs or iterations we use, we should get one cluster, + // centered at the mean of the points + + var model = KMeans.train(data, k=1, maxIterations=1) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=2) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + } + + test("k-means|| initialization") { + val points = Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0), + Array(1.0, 0.0, 1.0), + Array(1.0, 1.0, 1.0) + ) + val rdd = sc.parallelize(points) + + // K-means|| initialization should place all clusters into distinct centers because + // it will make at least five passes, and it will give non-zero probability to each + // unselected point as long as it hasn't yet selected all of them + + var model = KMeans.train(rdd, k=5, maxIterations=1) + assertSetsEqual(model.clusterCenters, points) + + // Iterations of Lloyd's should not change the answer either + model = KMeans.train(rdd, k=5, maxIterations=10) + assertSetsEqual(model.clusterCenters, points) + + // Neither should more runs + model = KMeans.train(rdd, k=5, maxIterations=10, runs=5) + assertSetsEqual(model.clusterCenters, points) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala new file mode 100644 index 0000000000..347ef238f4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +import org.jblas._ + +object ALSSuite { + + def generateRatingsAsJavaList( + users: Int, + products: Int, + features: Int, + samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { + val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) + (seqAsJavaList(sampledRatings), trueRatings) + } + + def generateRatings( + users: Int, + products: Int, + features: Int, + samplingRate: Double): (Seq[Rating], DoubleMatrix) = { + val rand = new Random(42) + + // Create a random matrix with uniform values from -1 to 1 + def randomMatrix(m: Int, n: Int) = + new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*) + + val userMatrix = randomMatrix(users, features) + val productMatrix = randomMatrix(features, products) + val trueRatings = userMatrix.mmul(productMatrix) + + val sampledRatings = { + for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) + yield Rating(u, p, trueRatings.get(u, p)) + } + + (sampledRatings, trueRatings) + } + +} + + +class ALSSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("rank-1 matrices") { + testALS(10, 20, 1, 15, 0.7, 0.3) + } + + test("rank-2 matrices") { + testALS(20, 30, 2, 15, 0.7, 0.3) + } + + /** + * Test if we can correctly factorize R = U * P where U and P are of known rank. + * + * @param users number of users + * @param products number of products + * @param features number of features (rank of problem) + * @param iterations number of iterations to run + * @param samplingRate what fraction of the user-product pairs are known + * @param matchThreshold max difference allowed to consider a predicted rating correct + */ + def testALS(users: Int, products: Int, features: Int, iterations: Int, + samplingRate: Double, matchThreshold: Double) + { + val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, + features, samplingRate) + val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) + + val predictedU = new DoubleMatrix(users, features) + for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { + predictedU.put(u, i, vec(i)) + } + val predictedP = new DoubleMatrix(products, features) + for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { + predictedP.put(p, i, vec(i)) + } + val predictedRatings = predictedU.mmul(predictedP.transpose) + + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val correct = trueRatings.get(u, p) + if (math.abs(prediction - correct) > matchThreshold) { + fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + } + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala new file mode 100644 index 0000000000..db980c7bae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.util.LinearDataGenerator + + +class LassoSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + }.size + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + test("Lasso local random SGD") { + val nPoints = 10000 + + val A = 2.0 + val B = -1.5 + val C = 1.0e-2 + + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val ls = new LassoWithSGD() + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + + val model = ls.run(testRDD) + + val weight0 = model.weights(0) + val weight1 = model.weights(1) + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + + val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("Lasso local random SGD with initial weights") { + val nPoints = 10000 + + val A = 2.0 + val B = -1.5 + val C = 1.0e-2 + + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + + val initialB = -1.0 + val initialC = -1.0 + val initialWeights = Array(initialB,initialC) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val ls = new LassoWithSGD() + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + + val model = ls.run(testRDD, initialWeights) + + val weight0 = model.weights(0) + val weight1 = model.weights(1) + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + + val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData,2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala new file mode 100644 index 0000000000..ef500c704c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LinearDataGenerator + +class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + }.size + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 + test("linear regression") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD() + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept >= 2.5 && model.intercept <= 3.5) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala new file mode 100644 index 0000000000..c18092d804 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.jblas.DoubleMatrix +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LinearDataGenerator + +class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + predictions.zip(input).map { case (prediction, expected) => + (prediction - expected.label) * (prediction - expected.label) + }.reduceLeft(_ + _) / predictions.size + } + + test("regularization with skewed weights") { + val nexamples = 200 + val nfeatures = 20 + val eps = 10 + + org.jblas.util.Random.seed(42) + // Pick weights as random values distributed uniformly in [-0.5, 0.5] + val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + // Set first two weights to eps + w.put(0, 0, eps) + w.put(1, 0, eps) + + // Use half of data for training and other half for validation + val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) + val testData = data.take(nexamples) + val validationData = data.takeRight(nexamples) + + val testRDD = sc.parallelize(testData, 2).cache() + val validationRDD = sc.parallelize(validationData, 2).cache() + + // First run without regularization. + val linearReg = new LinearRegressionWithSGD() + linearReg.optimizer.setNumIterations(200) + .setStepSize(1.0) + + val linearModel = linearReg.run(testRDD) + val linearErr = predictionError( + linearModel.predict(validationRDD.map(_.features)).collect(), validationData) + + val ridgeReg = new RidgeRegressionWithSGD() + ridgeReg.optimizer.setNumIterations(200) + .setRegParam(0.1) + .setStepSize(1.0) + val ridgeModel = ridgeReg.run(testRDD) + val ridgeErr = predictionError( + ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) + + // Ridge CV-error should be lower than linear regression + assert(ridgeErr < linearErr, + "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") + } +} diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala deleted file mode 100644 index bd87c528c3..0000000000 --- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.util.Random -import scala.collection.JavaConversions._ - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -import spark.SparkContext -import spark.mllib.regression._ - -object LogisticRegressionSuite { - - def generateLogisticInputAsList( - offset: Double, - scale: Double, - nPoints: Int, - seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) - } - - // Generate input of the form Y = logistic(offset + scale*X) - def generateLogisticInput( - offset: Double, - scale: Double, - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { - val rnd = new Random(seed) - val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) - - // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) - val unifRand = new scala.util.Random(45) - val rLogis = (0 until nPoints).map { i => - val u = unifRand.nextDouble() - math.log(u) - math.log(1.0-u) - } - - // y <- A + B*x + rLogis() - // y <- as.numeric(y > 0) - val y: Seq[Int] = (0 until nPoints).map { i => - val yVal = offset + scale * x1(i) + rLogis(i) - if (yVal > 0) 1 else 0 - } - - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) - testData - } - -} - -class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - (prediction != expected.label) - }.size - // At least 83% of the predictions should be on. - ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 - } - - // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { - val nPoints = 10000 - val A = 2.0 - val B = -1.5 - - val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - val lr = new LogisticRegressionWithSGD() - lr.optimizer.setStepSize(10.0).setNumIterations(20) - - val model = lr.run(testRDD) - - // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - - val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("logistic regression with initial weights") { - val nPoints = 10000 - val A = 2.0 - val B = -1.5 - - val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) - - val initialB = -1.0 - val initialWeights = Array(initialB) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - // Use half as many iterations as the previous test. - val lr = new LogisticRegressionWithSGD() - lr.optimizer.setStepSize(10.0).setNumIterations(10) - - val model = lr.run(testRDD, initialWeights) - - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - - val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala deleted file mode 100644 index 894ae458ad..0000000000 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.util.Random -import scala.math.signum -import scala.collection.JavaConversions._ - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.mllib.regression._ - -import org.jblas.DoubleMatrix - -object SVMSuite { - - def generateSVMInputAsList( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) - } - - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) - def generateSVMInput( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { - val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) - val y = x.map { xi => - val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + - intercept + 0.01 * rnd.nextGaussian() - if (yD < 0) 0.0 else 1.0 - } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) - } - -} - -class SVMSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - (prediction != expected.label) - }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) - } - - - test("SVM using local random SGD") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val svm = new SVMWithSGD() - svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) - - val model = svm.run(testRDD) - - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("SVM local random SGD with initial weights") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - - val initialB = -1.0 - val initialC = -1.0 - val initialWeights = Array(initialB,initialC) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val svm = new SVMWithSGD() - svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) - - val model = svm.run(testRDD, initialWeights) - - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("SVM with invalid labels") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - val testRDD = sc.parallelize(testData, 2) - - val testRDDInvalid = testRDD.map { lp => - if (lp.label == 0.0) { - LabeledPoint(-1.0, lp.features) - } else { - lp - } - } - - intercept[spark.SparkException] { - val model = SVMWithSGD.train(testRDDInvalid, 100) - } - - // Turning off data validation should not throw an exception - val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) - } -} diff --git a/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala deleted file mode 100644 index d5d95c8639..0000000000 --- a/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ - -import org.jblas._ - -class KMeansSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - val EPSILON = 1e-4 - - import KMeans.{RANDOM, K_MEANS_PARALLEL} - - def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")") - - def prettyPrint(points: Array[Array[Double]]): String = { - points.map(prettyPrint).mkString("(", "; ", ")") - } - - // L1 distance between two points - def distance1(v1: Array[Double], v2: Array[Double]): Double = { - v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max - } - - // Assert that two vectors are equal within tolerance EPSILON - def assertEqual(v1: Array[Double], v2: Array[Double]) { - def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2) - assert(v1.length == v2.length, errorMessage) - assert(distance1(v1, v2) <= EPSILON, errorMessage) - } - - // Assert that two sets of points are equal, within EPSILON tolerance - def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) { - def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2) - assert(set1.length == set2.length, errorMessage) - for (v <- set1) { - val closestDistance = set2.map(w => distance1(v, w)).min - if (closestDistance > EPSILON) { - fail(errorMessage) - } - } - for (v <- set2) { - val closestDistance = set1.map(w => distance1(v, w)).min - if (closestDistance > EPSILON) { - fail(errorMessage) - } - } - } - - test("single cluster") { - val data = sc.parallelize(Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0) - )) - - // No matter how many runs or iterations we use, we should get one cluster, - // centered at the mean of the points - - var model = KMeans.train(data, k=1, maxIterations=1) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=2) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train( - data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - } - - test("single cluster with big dataset") { - val smallData = Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0) - ) - val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) - - // No matter how many runs or iterations we use, we should get one cluster, - // centered at the mean of the points - - var model = KMeans.train(data, k=1, maxIterations=1) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=2) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - } - - test("k-means|| initialization") { - val points = Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0), - Array(1.0, 0.0, 1.0), - Array(1.0, 1.0, 1.0) - ) - val rdd = sc.parallelize(points) - - // K-means|| initialization should place all clusters into distinct centers because - // it will make at least five passes, and it will give non-zero probability to each - // unselected point as long as it hasn't yet selected all of them - - var model = KMeans.train(rdd, k=5, maxIterations=1) - assertSetsEqual(model.clusterCenters, points) - - // Iterations of Lloyd's should not change the answer either - model = KMeans.train(rdd, k=5, maxIterations=10) - assertSetsEqual(model.clusterCenters, points) - - // Neither should more runs - model = KMeans.train(rdd, k=5, maxIterations=10, runs=5) - assertSetsEqual(model.clusterCenters, points) - } -} diff --git a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala deleted file mode 100644 index 15a60efda6..0000000000 --- a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ - -import org.jblas._ - -object ALSSuite { - - def generateRatingsAsJavaList( - users: Int, - products: Int, - features: Int, - samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { - val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) - (seqAsJavaList(sampledRatings), trueRatings) - } - - def generateRatings( - users: Int, - products: Int, - features: Int, - samplingRate: Double): (Seq[Rating], DoubleMatrix) = { - val rand = new Random(42) - - // Create a random matrix with uniform values from -1 to 1 - def randomMatrix(m: Int, n: Int) = - new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*) - - val userMatrix = randomMatrix(users, features) - val productMatrix = randomMatrix(features, products) - val trueRatings = userMatrix.mmul(productMatrix) - - val sampledRatings = { - for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) - yield Rating(u, p, trueRatings.get(u, p)) - } - - (sampledRatings, trueRatings) - } - -} - - -class ALSSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - test("rank-1 matrices") { - testALS(10, 20, 1, 15, 0.7, 0.3) - } - - test("rank-2 matrices") { - testALS(20, 30, 2, 15, 0.7, 0.3) - } - - /** - * Test if we can correctly factorize R = U * P where U and P are of known rank. - * - * @param users number of users - * @param products number of products - * @param features number of features (rank of problem) - * @param iterations number of iterations to run - * @param samplingRate what fraction of the user-product pairs are known - * @param matchThreshold max difference allowed to consider a predicted rating correct - */ - def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double) - { - val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, - features, samplingRate) - val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) - - val predictedU = new DoubleMatrix(users, features) - for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { - predictedU.put(u, i, vec(i)) - } - val predictedP = new DoubleMatrix(products, features) - for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { - predictedP.put(p, i, vec(i)) - } - val predictedRatings = predictedU.mmul(predictedP.transpose) - - for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val correct = trueRatings.get(u, p) - if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) - } - } - } -} - diff --git a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala deleted file mode 100644 index 622dbbab7f..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.mllib.util.LinearDataGenerator - - -class LassoSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected.label) > 0.5 - }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) - } - - test("Lasso local random SGD") { - val nPoints = 10000 - - val A = 2.0 - val B = -1.5 - val C = 1.0e-2 - - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) - - val model = ls.run(testRDD) - - val weight0 = model.weights(0) - val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") - - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("Lasso local random SGD with initial weights") { - val nPoints = 10000 - - val A = 2.0 - val B = -1.5 - val C = 1.0e-2 - - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val initialB = -1.0 - val initialC = -1.0 - val initialWeights = Array(initialB,initialC) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) - - val model = ls.run(testRDD, initialWeights) - - val weight0 = model.weights(0) - val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") - - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala deleted file mode 100644 index acc48a3283..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ -import spark.mllib.util.LinearDataGenerator - -class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected.label) > 0.5 - }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) - } - - // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 - test("linear regression") { - val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 100, 42), 2).cache() - val linReg = new LinearRegressionWithSGD() - linReg.optimizer.setNumIterations(1000).setStepSize(1.0) - - val model = linReg.run(testRDD) - - assert(model.intercept >= 2.5 && model.intercept <= 3.5) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) - - val validationData = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 100, 17) - val validationRDD = sc.parallelize(validationData, 2).cache() - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala deleted file mode 100644 index c482035706..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.jblas.DoubleMatrix -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ -import spark.mllib.util.LinearDataGenerator - -class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { - predictions.zip(input).map { case (prediction, expected) => - (prediction - expected.label) * (prediction - expected.label) - }.reduceLeft(_ + _) / predictions.size - } - - test("regularization with skewed weights") { - val nexamples = 200 - val nfeatures = 20 - val eps = 10 - - org.jblas.util.Random.seed(42) - // Pick weights as random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) - // Set first two weights to eps - w.put(0, 0, eps) - w.put(1, 0, eps) - - // Use half of data for training and other half for validation - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) - val testData = data.take(nexamples) - val validationData = data.takeRight(nexamples) - - val testRDD = sc.parallelize(testData, 2).cache() - val validationRDD = sc.parallelize(validationData, 2).cache() - - // First run without regularization. - val linearReg = new LinearRegressionWithSGD() - linearReg.optimizer.setNumIterations(200) - .setStepSize(1.0) - - val linearModel = linearReg.run(testRDD) - val linearErr = predictionError( - linearModel.predict(validationRDD.map(_.features)).collect(), validationData) - - val ridgeReg = new RidgeRegressionWithSGD() - ridgeReg.optimizer.setNumIterations(200) - .setRegParam(0.1) - .setStepSize(1.0) - val ridgeModel = ridgeReg.run(testRDD) - val ridgeErr = predictionError( - ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) - - // Ridge CV-error should be lower than linear regression - assert(ridgeErr < linearErr, - "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") - } -} diff --git a/pom.xml b/pom.xml index e2fd54a966..9230611eae 100644 --- a/pom.xml +++ b/pom.xml @@ -18,22 +18,22 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT pom Spark Project Parent POM - http://spark-project.org/ + http://spark.incubator.apache.org/ - BSD License - https://github.com/mesos/spark/blob/master/LICENSE + Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html repo - scm:git:git@github.com:mesos/spark.git - scm:git:git@github.com:mesos/spark.git + scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/incubator-spark.git @@ -46,7 +46,7 @@ - github + JIRA https://spark-project.atlassian.net/browse/SPARK diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2e26812671..18e86d2cae 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -74,7 +74,7 @@ object SparkBuild extends Build { core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef def sharedSettings = Defaults.defaultSettings ++ Seq( - organization := "org.spark-project", + organization := "org.apache.spark", version := "0.8.0-SNAPSHOT", scalaVersion := "2.9.3", scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), @@ -103,7 +103,7 @@ object SparkBuild extends Build { //useGpg in Global := true, pomExtra := ( - http://spark-project.org/ + http://spark.incubator.apache.org/ Apache 2.0 License @@ -112,8 +112,8 @@ object SparkBuild extends Build { - scm:git:git@github.com:mesos/spark.git - scm:git:git@github.com:mesos/spark.git + scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/incubator-spark.git @@ -125,6 +125,10 @@ object SparkBuild extends Build { http://www.cs.berkeley.edu/ + + JIRA + https://spark-project.atlassian.net/browse/SPARK + ), /* diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2803ce90f3..906e9221a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -114,9 +114,9 @@ class SparkContext(object): self.addPyFile(path) # Create a temporary directory inside spark.local.dir: - local_dir = self._jvm.spark.Utils.getLocalDir() + local_dir = self._jvm.org.apache.spark.Utils.getLocalDir() self._temp_dir = \ - self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.Utils.createTempDir(local_dir).getAbsolutePath() @property def defaultParallelism(self): diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 89bcbcfe06..57ee14eeb7 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -52,4 +52,4 @@ class SparkFiles(object): return cls._root_directory else: # This will have to change if we support multiple SparkContexts: - return cls._sc._jvm.spark.SparkFiles.getRootDirectory() + return cls._sc._jvm.org.apache.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3ccf062c86..26fbe0f080 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -53,7 +53,7 @@ def launch_gateway(): # Connect to the gateway gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) # Import the classes used by PySpark - java_import(gateway.jvm, "spark.api.java.*") - java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.api.java.*") + java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 919e35f240..6a1b09e8df 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-repl-bin pom Spark Project REPL binary packaging @@ -39,18 +39,18 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} runtime - org.spark-project + org.apache.spark spark-repl ${project.version} runtime @@ -109,7 +109,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} diff --git a/repl/pom.xml b/repl/pom.xml index f800664cff..f6276f1895 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-repl jar Spark Project REPL @@ -38,18 +38,18 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} runtime - org.spark-project + org.apache.spark spark-mllib ${project.version} runtime @@ -136,7 +136,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala new file mode 100644 index 0000000000..3e171849e3 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.{ByteArrayOutputStream, InputStream} +import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.util.concurrent.{Executors, ExecutorService} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.objectweb.asm._ +import org.objectweb.asm.Opcodes._ + + +/** + * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, + * used to load classes defined by the interpreter when the REPL is used + */ +class ExecutorClassLoader(classUri: String, parent: ClassLoader) +extends ClassLoader(parent) { + val uri = new URI(classUri) + val directory = uri.getPath + + // Hadoop FileSystem object for our URI, if it isn't using HTTP + var fileSystem: FileSystem = { + if (uri.getScheme() == "http") + null + else + FileSystem.get(uri, new Configuration()) + } + + override def findClass(name: String): Class[_] = { + try { + val pathInDirectory = name.replace('.', '/') + ".class" + val inputStream = { + if (fileSystem != null) + fileSystem.open(new Path(directory, pathInDirectory)) + else + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } + val bytes = readAndTransformClass(name, inputStream) + inputStream.close() + return defineClass(name, bytes, 0, bytes.length) + } catch { + case e: Exception => throw new ClassNotFoundException(name, e) + } + } + + def readAndTransformClass(name: String, in: InputStream): Array[Byte] = { + if (name.startsWith("line") && name.endsWith("$iw$")) { + // Class seems to be an interpreter "wrapper" object storing a val or var. + // Replace its constructor with a dummy one that does not run the + // initialization code placed there by the REPL. The val or var will + // be initialized later through reflection when it is used in a task. + val cr = new ClassReader(in) + val cw = new ClassWriter( + ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS) + val cleaner = new ConstructorCleaner(name, cw) + cr.accept(cleaner, 0) + return cw.toByteArray + } else { + // Pass the class through unmodified + val bos = new ByteArrayOutputStream + val bytes = new Array[Byte](4096) + var done = false + while (!done) { + val num = in.read(bytes) + if (num >= 0) + bos.write(bytes, 0, num) + else + done = true + } + return bos.toByteArray + } + } + + /** + * URL-encode a string, preserving only slashes + */ + def urlEncode(str: String): String = { + str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/") + } +} + +class ConstructorCleaner(className: String, cv: ClassVisitor) +extends ClassVisitor(ASM4, cv) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + val mv = cv.visitMethod(access, name, desc, sig, exceptions) + if (name == "" && (access & ACC_STATIC) == 0) { + // This is the constructor, time to clean it; just output some new + // instructions to mv that create the object and set the static MODULE$ + // field in the class to point to it, but do nothing otherwise. + mv.visitCode() + mv.visitVarInsn(ALOAD, 0) // load this + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitVarInsn(ALOAD, 0) // load this + //val classType = className.replace('.', '/') + //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") + mv.visitInsn(RETURN) + mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed + mv.visitEnd() + return null + } else { + return mv + } + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 0000000000..17e149f8ab --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.collection.mutable.Set + +object Main { + private var _interp: SparkILoop = null + + def interp = _interp + + def interp_=(i: SparkILoop) { _interp = i } + + def main(args: Array[String]) { + _interp = new SparkILoop + _interp.process(args) + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala new file mode 100644 index 0000000000..d8fb7191b4 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala @@ -0,0 +1,5 @@ +package scala.tools.nsc + +object SparkHelper { + def explicitParentLoader(settings: Settings) = settings.explicitParentLoader +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000..193ccb48ee --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,1008 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2011 LAMP/EPFL + * @author Alexander Spoon + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import Predef.{ println => _, _ } +import java.io.{ BufferedReader, FileReader, PrintWriter } +import scala.sys.process.Process +import session._ +import scala.tools.nsc.interpreter.{ Results => IR } +import scala.tools.util.{ SignalManager, Signallable, Javap } +import scala.annotation.tailrec +import scala.util.control.Exception.{ ignoring } +import scala.collection.mutable.ListBuffer +import scala.concurrent.ops +import util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import interpreter._ +import io.{ File, Sources } + +import org.apache.spark.Logging +import org.apache.spark.SparkContext + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. + * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Option[String]) + extends AnyRef + with LoopCommands + with Logging +{ + def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) + def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None) + def this() = this(None, new PrintWriter(Console.out, true), None) + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + /* + lazy val power = { + val g = intp.global + Power[g.type](this, g) + } + */ + + // TODO + // object opt extends AestheticSettings + // + @deprecated("Use `intp` instead.", "2.9.0") + def interpreter = intp + + @deprecated("Use `intp` instead.", "2.9.0") + def interpreter_= (i: SparkIMain): Unit = intp = i + + def history = in.history + + /** The context class loader at the time this object was created */ + protected val originalClassLoader = Thread.currentThread.getContextClassLoader + + // Install a signal handler so we can be prodded. + private val signallable = + /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand()) + else*/ null + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + /** Try to install sigint handler: ignore failure. Signal handler + * will interrupt current line execution if any is in progress. + * + * Attempting to protect the repl from accidental exit, we only honor + * a single ctrl-C if the current buffer is empty: otherwise we look + * for a second one within a short time. + */ + private def installSigIntHandler() { + def onExit() { + Console.println("") // avoiding "shell prompt in middle of line" syndrome + sys.exit(1) + } + ignoring(classOf[Exception]) { + SignalManager("INT") = { + if (intp == null) + onExit() + else if (intp.lineManager.running) + intp.lineManager.cancel() + else if (in.currentLine != "") { + // non-empty buffer, so make them hit ctrl-C a second time + SignalManager("INT") = onExit() + io.timer(5)(installSigIntHandler()) // and restore original handler if they don't + } + else onExit() + } + } + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close + intp = null + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def createLineManager() = new Line.Manager { + override def onRunaway(line: Line[_]): Unit = { + val template = """ + |// She's gone rogue, captain! Have to take her out! + |// Calling Thread.stop on runaway %s with offending code: + |// scala> %s""".stripMargin + + echo(template.format(line.thread, line.code)) + // XXX no way to suppress the deprecation warning + line.thread.stop() + in.redrawLine() + } + } + override protected def parentClassLoader = { + SparkHelper.explicitParentLoader(settings).getOrElse( classOf[SparkILoop].getClassLoader ) + } + } + + /** Create a new interpreter. */ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + intp.setContextClassLoader() + installSigIntHandler() + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.longHelp) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s %s" + + echo("All commands can be abbreviated, e.g. :he instead of :help.") + echo("Those marked with a * have more detailed help, e.g. :help imports.\n") + + commands foreach { cmd => + val star = if (cmd.hasLongHelp) "*" else " " + echo(formatStr.format(cmd.usageMsg, star, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Print a welcome message */ + def printWelcome() { + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version 0.8.0 + /_/ +""") + import Properties._ + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + private def echo(msg: String) = { + out println msg + out.flush() + } + private def echoNoNL(msg: String) = { + out print msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private var currentPrompt = Properties.shellPromptString + def setPrompt(prompt: String) = currentPrompt = prompt + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "", "add a jar or directory to the classpath", addClasspath), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), + cmd("javap", "", "disassemble a file or class name", javapCommand), + nullary("keybindings", "show how ctrl-[A-Z] and other keys are bound", keybindingsCommand), + cmd("load", "", "load and interpret a Scala file", loadCommand), + nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), + //nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + shCommand, + nullary("silent", "disable/enable automatic printing of results", verbosity), + cmd("type", "", "display the type of an expression without evaluating it", typeCommand) + ) + + /** Power user commands */ + lazy val powerCommands: List[LoopCommand] = List( + //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand), + //cmd("phase", "", "set the implicit phase for power commands", phaseCommand), + cmd("wrap", "", "name of method to wrap around each repl line", wrapCommand) withLongHelp (""" + |:wrap + |:wrap clear + |:wrap + | + |Installs a wrapper around each line entered into the repl. + |Currently it must be the simple name of an existing method + |with the specific signature shown in the following example. + | + |def timed[T](body: => T): T = { + | val start = System.nanoTime + | try body + | finally println((System.nanoTime - start) + " nanos elapsed.") + |} + |:wrap timed + | + |If given no argument, :wrap names the wrapper installed. + |An argument of clear will remove the wrapper if any is active. + |Note that wrappers do not compose (a new one replaces the old + |one) and also that the :phase command uses the same machinery, + |so setting :wrap will clear any :phase setting. + """.stripMargin.trim) + ) + + /* + private def dumpCommand(): Result = { + echo("" + power) + history.asStrings takeRight 30 foreach echo + in.redrawLine() + } + */ + + private val typeTransforms = List( + "scala.collection.immutable." -> "immutable.", + "scala.collection.mutable." -> "mutable.", + "scala.collection.generic." -> "generic.", + "java.lang." -> "jl.", + "scala.runtime." -> "runtime." + ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + val isVerbose = tokens contains "-v" + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def implicitsCommand(line: String): Result = { + val intp = SparkILoop.this.intp + import intp._ + import global.Symbol + + def p(x: Any) = intp.reporter.printMessage("" + x) + + // If an argument is given, only show a source with that + // in its name somewhere. + val args = line split "\\s+" + val filtered = intp.implicitSymbolsBySource filter { + case (source, syms) => + (args contains "-v") || { + if (line == "") (source.fullName.toString != "scala.Predef") + else (args exists (source.name.toString contains _)) + } + } + + if (filtered.isEmpty) + return "No implicits have been imported other than those in Predef." + + filtered foreach { + case (source, syms) => + p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") + + // This groups the members by where the symbol is defined + val byOwner = syms groupBy (_.owner) + val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) } + + sortedOwners foreach { + case (owner, members) => + // Within each owner, we cluster results based on the final result type + // if there are more than a couple, and sort each cluster based on name. + // This is really just trying to make the 100 or so implicits imported + // by default into something readable. + val memberGroups: List[List[Symbol]] = { + val groups = members groupBy (_.tpe.finalResultType) toList + val (big, small) = groups partition (_._2.size > 3) + val xss = ( + (big sortBy (_._1.toString) map (_._2)) :+ + (small flatMap (_._2)) + ) + + xss map (xs => xs sortBy (_.name.toString)) + } + + val ownerMessage = if (owner == source) " defined in " else " inherited from " + p(" /* " + members.size + ownerMessage + owner.fullName + " */") + + memberGroups foreach { group => + group foreach (s => p(" " + intp.symbolDefString(s))) + p("") + } + } + p("") + } + } + + protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) { + override def tryClass(path: String): Array[Byte] = { + // Look for Foo first, then Foo$, but if Foo$ is given explicitly, + // we have to drop the $ to find object Foo, then tack it back onto + // the end of the flattened name. + def className = intp flatName path + def moduleName = (intp flatName path.stripSuffix("$")) + "$" + + val bytes = super.tryClass(className) + if (bytes.nonEmpty) bytes + else super.tryClass(moduleName) + } + } + private lazy val javap = + try newJavap() + catch { case _: Exception => null } + + private def typeCommand(line: String): Result = { + intp.typeOfExpression(line) match { + case Some(tp) => tp.toString + case _ => "Failed to determine type." + } + } + + private def javapCommand(line: String): Result = { + if (javap == null) + return ":javap unavailable on this platform." + if (line == "") + return ":javap [-lcsvp] [path1 path2 ...]" + + javap(words(line)) foreach { res => + if (res.isError) return "Failed: " + res.value + else res.show() + } + } + private def keybindingsCommand(): Result = { + if (in.keyBindings.isEmpty) "Key bindings unavailable." + else { + echo("Reading jline properties for default key bindings.") + echo("Accuracy not guaranteed: treat this as a guideline only.\n") + in.keyBindings foreach (x => echo ("" + x)) + } + } + private def wrapCommand(line: String): Result = { + def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" + val intp = SparkILoop.this.intp + val g: intp.global.type = intp.global + import g._ + + words(line) match { + case Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => "Current execution wrapper: " + s + } + case "clear" :: Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." + } + case wrapper :: Nil => + intp.typeOfExpression(wrapper) match { + case Some(PolyType(List(targ), MethodType(List(arg), restpe))) => + intp setExecutionWrapper intp.pathToTerm(wrapper) + "Set wrapper to '" + wrapper + "'" + case Some(x) => + failMsg + "\nFound: " + x + case _ => + failMsg + "\nFound: " + } + case _ => failMsg + } + } + + private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" + /* + private def phaseCommand(name: String): Result = { + // This line crashes us in TreeGen: + // + // if (intp.power.phased set name) "..." + // + // Exception in thread "main" java.lang.AssertionError: assertion failed: ._7.type + // at scala.Predef$.assert(Predef.scala:99) + // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:69) + // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:44) + // at scala.tools.nsc.ast.TreeGen.mkAttributedRef(TreeGen.scala:101) + // at scala.tools.nsc.ast.TreeGen.mkAttributedStableRef(TreeGen.scala:143) + // + // But it works like so, type annotated. + val phased: Phased = power.phased + import phased.NoPhaseName + + if (name == "clear") { + phased.set(NoPhaseName) + intp.clearExecutionWrapper() + "Cleared active phase." + } + else if (name == "") phased.get match { + case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" + case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) + } + else { + val what = phased.parse(name) + if (what.isEmpty || !phased.set(what)) + "'" + name + "' does not appear to represent a valid phase." + else { + intp.setExecutionWrapper(pathToPhaseWrapper) + val activeMessage = + if (what.toString.length == name.length) "" + what + else "%s (%s)".format(what, name) + + "Active phase is now: " + activeMessage + } + } + } + */ + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands /* ++ ( + if (isReplPower) powerCommands else Nil + )*/ + + val replayQuestionMessage = + """|The repl compiler has crashed spectacularly. Shall I replay your + |session? I can re-run all lines except the last one. + |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Unit] = { + case ex: Throwable => + if (settings.YrichExes.value) { + val sources = implicitly[Sources] + echo("\n" + ex.getMessage) + echo( + if (isReplDebug) "[searching " + sources.path + " for exception contexts...]" + else "[searching for exception contexts...]" + ) + echo(Exceptional(ex).force().context()) + } + else { + echo(util.stackTraceString(ex)) + } + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("Unrecoverable error.") + throw ex + case _ => + def fn(): Boolean = in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. + */ + def loop() { + def readOneLine() = { + out.flush() + in readLine prompt + } + // return false if repl should exit + def processLine(line: String): Boolean = + if (line eq null) false // assume null means EOF + else command(line) match { + case Result(false, _) => false + case Result(_, Some(finalLine)) => addReplay(finalLine) ; true + case _ => true + } + + while (true) { + try if (!processLine(readOneLine)) return + catch crashRecovery + } + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + val oldIn = in + val oldReplay = replayCommandStack + + try file applyReader { reader => + in = SimpleReader(reader, out, false) + echo("Loading " + file + "...") + loop() + } + finally { + in = oldIn + replayCommandStack = oldReplay + } + } + + /** create a new interpreter and replay all commands so far */ + def replay() { + closeInterpreter() + createInterpreter() + for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" + intp interpret toRun + () + } + } + + def withFile(filename: String)(action: File => Unit) { + val f = File(filename) + + if (f.exists) action(f) + else echo("That file does not exist") + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(true, shouldReplay) + } + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." + else enablePowerMode() + } + def enablePowerMode() = { + //replProps.power setValue true + //power.unleash() + //echo(power.banner) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(false, None) // Notice failure to create compiler + else Result(true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(): Result = { + echo("// Entering paste mode (ctrl-D to finish)\n") + val code = readWhile(_ => true) mkString "\n" + echo("\n// Exiting paste mode, now interpreting.\n") + intp interpret code + () + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + // Printing this message doesn't work very well because it's buried in the + // transcript they just pasted. Todo: a short timer goes off when + // lines stop coming which tells them to hit ctrl-D. + // + // echo("// Detected repl transcript paste: ctrl-D to finish.") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else { + def runCompletion = in.completion execute code map (intp bindValue _) + /** Due to my accidentally letting file completion execution sneak ahead + * of actual parsing this now operates in such a way that the scala + * interpretation always wins. However to avoid losing useful file + * completion I let it fail and then check the others. So if you + * type /tmp it will echo a failure and then give you a Directory object. + * It's not pretty: maybe I'll implement the silence bits I need to avoid + * echoing the failure. + */ + if (intp isParseable code) { + val (code, result) = reallyInterpret + //if (power != null && code == IR.Error) + // runCompletion + + result + } + else runCompletion match { + case Some(_) => None // completion hit: avoid the latent error + case _ => reallyInterpret._2 // trigger the latent error + } + } + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline.value || Properties.isEmacsShell) + SimpleReader() + else try SparkJLineReader( + if (settings.noCompletion.value) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + + def initializeSpark() { + intp.beQuietDuring { + command(""" + org.apache.spark.repl.Main.interp.out.println("Creating SparkContext..."); + org.apache.spark.repl.Main.interp.out.flush(); + @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + org.apache.spark.repl.Main.interp.out.println("Spark context available as sc."); + org.apache.spark.repl.Main.interp.out.flush(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + var sparkContext: SparkContext = null + + def createSparkContext(): SparkContext = { + val uri = System.getenv("SPARK_EXECUTOR_URI") + if (uri != null) { + System.setProperty("spark.executor.uri", uri) + } + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) + .getOrElse(new Array[String](0)) + .map(new java.io.File(_).getAbsolutePath) + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) + sparkContext + } + + def process(settings: Settings): Boolean = { + // Ensure logging is initialized before any Spark threads try to use logs + // (because SLF4J initialization is not thread safe) + initLogging() + + printWelcome() + echo("Initializing interpreter...") + + // Add JARS specified in Spark's ADD_JARS variable to classpath + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + jars.foreach(settings.classpath.append(_)) + + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0 match { + case Some(reader) => SimpleReader(reader, out, true) + case None => chooseReader(settings) + } + + loadFiles(settings) + // it is broken on startup; go ahead and exit + if (intp.reporter.hasErrors) + return false + + try { + // this is about the illusion of snappiness. We call initialize() + // which spins off a separate thread, then print the prompt and try + // our best to look ready. Ideally the user will spend a + // couple seconds saying "wow, it starts so fast!" and by the time + // they type a command the compiler is ready to roll. + intp.initialize() + initializeSpark() + if (isReplPower) { + echo("Starting in power mode, one moment...\n") + enablePowerMode() + } + loop() + } + finally closeInterpreter() + true + } + + /** process command-line arguments and do as they request */ + def process(args: Array[String]): Boolean = { + val command = new CommandLine(args.toList, msg => echo("scala: " + msg)) + def neededHelp(): String = + (if (command.settings.help.value) command.usageMsg + "\n" else "") + + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") + + // if they asked for no help and command is valid, we call the real main + neededHelp() match { + case "" => command.ok && process(command.settings) + case help => echoNoNL(help) ; true + } + } + + @deprecated("Use `process` instead", "2.9.0") + def main(args: Array[String]): Unit = { + if (isReplDebug) + System.out.println(new java.util.Date) + + process(args) + } + @deprecated("Use `process` instead", "2.9.0") + def main(settings: Settings): Unit = process(settings) +} + +object SparkILoop { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + private def echo(msg: String) = Console println msg + + // Designed primarily for use by test code: take a String with a + // bunch of code, and prints out a transcript of what it would look + // like if you'd just typed it into the repl. + def runForTranscript(code: String, settings: Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val output = new PrintWriter(new OutputStreamWriter(ostream), true) { + override def write(str: String) = { + // completely skip continuation lines + if (str forall (ch => ch.isWhitespace || ch == '|')) () + // print a newline on empty scala prompts + else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n") + else super.write(str) + } + } + val input = new BufferedReader(new StringReader(code)) { + override def readLine(): String = { + val s = super.readLine() + // helping out by printing the line being interpreted. + if (s != null) + output.println(s) + s + } + } + val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) + settings.classpath.value = sys.props("java.class.path") + + repl process settings + } + } + } + + /** Creates an interpreter loop with default settings and feeds + * the given code to it as input. + */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new PrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) + sets.classpath.value = sys.props("java.class.path") + + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + + // provide the enclosing type T + // in order to set up the interpreter's classpath and parent class loader properly + def breakIf[T: Manifest](assertion: => Boolean, args: NamedParam*): Unit = + if (assertion) break[T](args.toList) + + // start a repl, binding supplied args + def break[T: Manifest](args: List[NamedParam]): Unit = { + val msg = if (args.isEmpty) "" else " Binding " + args.size + " value%s.".format( + if (args.size == 1) "" else "s" + ) + echo("Debug repl starting." + msg) + val repl = new SparkILoop { + override def prompt = "\ndebug> " + } + repl.settings = new Settings(echo) + repl.settings.embeddedDefaults[T] + repl.createInterpreter() + repl.in = SparkJLineReader(repl) + + // rebind exit so people don't accidentally call sys.exit by way of predef + repl.quietRun("""def exit = println("Type :quit to resume program execution.")""") + args foreach (p => repl.bind(p.name, p.tpe, p.value)) + repl.loop() + + echo("\nDebug repl exiting.") + repl.closeInterpreter() + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 0000000000..7e244e48a2 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1160 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2011 LAMP/EPFL + * @author Martin Odersky + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import Predef.{ println => _, _ } +import java.io.{ PrintWriter } +import java.lang.reflect +import java.net.URL +import util.{ Set => _, _ } +import io.{ AbstractFile, PlainFile, VirtualDirectory } +import reporters.{ ConsoleReporter, Reporter } +import symtab.{ Flags, Names } +import scala.tools.nsc.interpreter.{ Results => IR } +import scala.tools.util.PathResolver +import scala.tools.nsc.util.{ ScalaClassLoader, Exceptional } +import ScalaClassLoader.URLClassLoader +import Exceptional.unwrap +import scala.collection.{ mutable, immutable } +import scala.PartialFunction.{ cond, condOpt } +import scala.util.control.Exception.{ ultimately } +import scala.reflect.NameTransformer +import SparkIMain._ + +import org.apache.spark.HttpServer +import org.apache.spark.Utils +import org.apache.spark.SparkEnv + +/** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports a single member named "$export". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + */ +class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends SparkImports { + imain => + + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this() = this(new Settings()) + + /** whether to print out result lines */ + var printResults: Boolean = true + + /** whether to print errors */ + var totalSilence: Boolean = false + + private val RESULT_OBJECT_PREFIX = "RequestResult$" + + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + import formatting._ + + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + + /** Local directory to save .class files too */ + val outputDir = { + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = System.getProperty("spark.repl.classdir", tmp) + Utils.createTempDir(rootDir) + } + if (SPARK_DEBUG_REPL) { + echo("Output directory: " + outputDir) + } + + /** Scala compiler virtual directory for outputDir */ + val virtualDirectory = new PlainFile(outputDir) + + /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir) + + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + System.setProperty("spark.repl.class.uri", classServer.uri) + if (SPARK_DEBUG_REPL) { + echo("Class server started, URI = " + classServer.uri) + } + + /* + // directory to save .class files to + val virtualDirectory = new VirtualDirectory("(memory)", None) { + private def pp(root: io.AbstractFile, indentLevel: Int) { + val spaces = " " * indentLevel + out.println(spaces + root.name) + if (root.isDirectory) + root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) + } + // print the contents hierarchically + def show() = pp(this, 0) + } + */ + + /** reporter */ + lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + import reporter.{ printMessage, withoutTruncating } + + // not sure if we have some motivation to print directly to console + private def echo(msg: String) { Console println msg } + + // protected def defaultImports: List[String] = List("_root_.scala.sys.exit") + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private val _compiler: Global = newCompiler(settings, reporter) + private var _initializeComplete = false + def isInitializeComplete = _initializeComplete + + private def _initialize(): Boolean = { + val source = """ + |class $repl_$init { + | List(1) map (_ + 1) + |} + |""".stripMargin + + val result = try { + new _compiler.Run() compileSources List(new BatchSourceFile("", source)) + if (isReplDebug || settings.debug.value) { + // Can't use printMessage here, it deadlocks + Console.println("Repl compiler initialized.") + } + // addImports(defaultImports: _*) + true + } + catch { + case x: AbstractMethodError => + printMessage(""" + |Failed to initialize compiler: abstract method error. + |This is most often remedied by a full clean and recompile. + |""".stripMargin + ) + x.printStackTrace() + false + case x: MissingRequirementError => printMessage(""" + |Failed to initialize compiler: %s not found. + |** Note that as of 2.8 scala does not assume use of the java classpath. + |** For the old behavior pass -usejavacp to scala, or if using a Settings + |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(x.req) + ) + false + } + + try result + finally _initializeComplete = result + } + + // set up initialization future + private var _isInitialized: () => Boolean = null + def initialize() = synchronized { + if (_isInitialized == null) + _isInitialized = scala.concurrent.ops future _initialize() + } + + /** the public, go through the future compiler */ + lazy val global: Global = { + initialize() + + // blocks until it is ; false means catastrophic failure + if (_isInitialized()) _compiler + else null + } + @deprecated("Use `global` for access to the compiler instance.", "2.9.0") + lazy val compiler: global.type = global + + import global._ + + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. + override def freshUserVarName(): String = { + val name = super.freshUserVarName() + if (definedNameMap contains name) freshUserVarName() + else name + } + } + import naming._ + + // object dossiers extends { + // val intp: imain.type = imain + // } with Dossiers { } + // import dossiers._ + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + def atPickler[T](op: => T): T = atPhase(currentRun.picklerPhase)(op) + def afterTyper[T](op: => T): T = atPhase(currentRun.typerPhase.next)(op) + + /** Temporarily be quiet */ + def beQuietDuring[T](operation: => T): T = { + val wasPrinting = printResults + ultimately(printResults = wasPrinting) { + if (isReplDebug) echo(">> beQuietDuring") + else printResults = false + + operation + } + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + /** whether to bind the lastException variable */ + private var bindLastException = true + + /** A string representing code to be wrapped around all lines. */ + private var _executionWrapper: String = "" + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** Temporarily stop binding lastException */ + def withoutBindingLastException[T](operation: => T): T = { + val wasBinding = bindLastException + ultimately(bindLastException = wasBinding) { + bindLastException = false + operation + } + } + + protected def createLineManager(): Line.Manager = new Line.Manager + lazy val lineManager = createLineManager() + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Subclasses can override this to + * change the compiler class used by this interpreter. */ + protected def newCompiler(settings: Settings, reporter: Reporter) = { + settings.outputDirs setSingleOutput virtualDirectory + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) + } + + /** the compiler's classpath, as URL's */ + lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + private var _classLoader: AbstractFileClassLoader = null + def resetClassLoader() = _classLoader = makeClassLoader() + def classLoader: AbstractFileClassLoader = { + if (_classLoader == null) + resetClassLoader() + + _classLoader + } + private def makeClassLoader(): AbstractFileClassLoader = { + val parent = + if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath + else new URLClassLoader(compilerClasspath, parentClassLoader) + + new AbstractFileClassLoader(virtualDirectory, parent) { + /** Overridden here to try translating a simple name to the generated + * class name if the original attempt fails. This method is used by + * getResourceAsStream as well as findClass. + */ + override protected def findAbstractFile(name: String): AbstractFile = { + super.findAbstractFile(name) match { + // deadlocks on startup if we try to translate names too early + case null if isInitializeComplete => generatedName(name) map (x => super.findAbstractFile(x)) orNull + case file => file + } + } + } + } + private def loadByName(s: String): JClass = + (classLoader tryToInitializeClass s) getOrElse sys.error("Failed to load expected class: '" + s + "'") + + protected def parentClassLoader: ClassLoader = + SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) + + def getInterpreterClassLoader() = classLoader + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() + + /** Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. for "Bippy" it may return + * + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + */ + def generatedName(simpleName: String): Option[String] = { + if (simpleName endsWith "$") optFlatName(simpleName.init) map (_ + "$") + else optFlatName(simpleName) + } + def flatName(id: String) = optFlatName(id) getOrElse id + def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + def allDefinedNames = definedNameMap.keys.toList sortBy (_.toString) + def pathToType(id: String): String = pathToName(newTypeName(id)) + def pathToTerm(id: String): String = pathToName(newTermName(id)) + def pathToName(name: Name): String = { + if (definedNameMap contains name) + definedNameMap(name) fullPath name + else name.toString + } + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalVarName(x.name.toString) => return Some(x.member) + case _ => () + } + } + None + } + + /** Stubs for work in progress. */ + def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { + DBG("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) + } + } + + def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { + // Printing the types here has a tendency to cause assertion errors, like + // assertion failed: fatal: has owner value x, but a class owner is required + // so DBG is by-name now to keep it in the family. (It also traps the assertion error, + // but we don't want to unnecessarily risk hosing the compiler's internal state.) + DBG("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) + } + } + def recordRequest(req: Request) { + if (req == null || referencedNameMap == null) + return + + prevRequests += req + req.referencedNames foreach (x => referencedNameMap(x) = req) + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. + if (!settings.nowarnings.value) { + for { + name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) + oldReq <- definedNameMap get name.companionName + newSym <- req.definedSymbols get name + oldSym <- oldReq.definedSymbols get name.companionName + } { + printMessage("warning: previously defined %s is not a companion to %s.".format(oldSym, newSym)) + printMessage("Companions must be defined together; you may wish to use :paste mode for this.") + } + } + + // Updating the defined name map + req.definedNames foreach { name => + if (definedNameMap contains name) { + if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) + else handleTermRedefinition(name.toTermName, definedNameMap(name), req) + } + definedNameMap(name) = req + } + } + + /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ + def parse(line: String): Option[List[Tree]] = { + var justNeedsMore = false + reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) { + // simple parse: just parse it, nothing else + def simpleParse(code: String): List[Tree] = { + reporter.reset() + val unit = new CompilationUnit(new BatchSourceFile("", code)) + val scanner = new syntaxAnalyzer.UnitParser(unit) + + scanner.templateStatSeq(false)._2 + } + val trees = simpleParse(line) + + if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop + else if (justNeedsMore) None + else Some(trees) + } + } + + def isParseable(line: String): Boolean = { + beSilentDuring { + parse(line) match { + case Some(xs) => xs.nonEmpty // parses as-is + case None => true // incomplete + } + } + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = { + reporter.reset() + new Run() compileSources sources.toList + !reporter.hasErrors + } + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("