aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/scala')
-rw-r--r--examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala123
-rw-r--r--examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala101
-rw-r--r--examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala223
3 files changed, 447 insertions, 0 deletions
diff --git a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala
new file mode 100644
index 0000000000..c23ee9895f
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.examples.bagel
+
+import spark._
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+
+import com.esotericsoftware.kryo._
+
+class PageRankUtils extends Serializable {
+ def computeWithCombiner(numVertices: Long, epsilon: Double)(
+ self: PRVertex, messageSum: Option[Double], superstep: Int
+ ): (PRVertex, Array[PRMessage]) = {
+ val newValue = messageSum match {
+ case Some(msgSum) if msgSum != 0 =>
+ 0.15 / numVertices + 0.85 * msgSum
+ case _ => self.value
+ }
+
+ val terminate = superstep >= 10
+
+ val outbox: Array[PRMessage] =
+ if (!terminate)
+ self.outEdges.map(targetId =>
+ new PRMessage(targetId, newValue / self.outEdges.size))
+ else
+ Array[PRMessage]()
+
+ (new PRVertex(newValue, self.outEdges, !terminate), outbox)
+ }
+
+ def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) =
+ computeWithCombiner(numVertices, epsilon)(self, messages match {
+ case Some(msgs) => Some(msgs.map(_.value).sum)
+ case None => None
+ }, superstep)
+}
+
+class PRCombiner extends Combiner[PRMessage, Double] with Serializable {
+ def createCombiner(msg: PRMessage): Double =
+ msg.value
+ def mergeMsg(combiner: Double, msg: PRMessage): Double =
+ combiner + msg.value
+ def mergeCombiners(a: Double, b: Double): Double =
+ a + b
+}
+
+class PRVertex() extends Vertex with Serializable {
+ var value: Double = _
+ var outEdges: Array[String] = _
+ var active: Boolean = _
+
+ def this(value: Double, outEdges: Array[String], active: Boolean = true) {
+ this()
+ this.value = value
+ this.outEdges = outEdges
+ this.active = active
+ }
+
+ override def toString(): String = {
+ "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString)
+ }
+}
+
+class PRMessage() extends Message[String] with Serializable {
+ var targetId: String = _
+ var value: Double = _
+
+ def this(targetId: String, value: Double) {
+ this()
+ this.targetId = targetId
+ this.value = value
+ }
+}
+
+class PRKryoRegistrator extends KryoRegistrator {
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[PRVertex])
+ kryo.register(classOf[PRMessage])
+ }
+}
+
+class CustomPartitioner(partitions: Int) extends Partitioner {
+ def numPartitions = partitions
+
+ def getPartition(key: Any): Int = {
+ val hash = key match {
+ case k: Long => (k & 0x00000000FFFFFFFFL).toInt
+ case _ => key.hashCode
+ }
+
+ val mod = key.hashCode % partitions
+ if (mod < 0) mod + partitions else mod
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case c: CustomPartitioner =>
+ c.numPartitions == numPartitions
+ case _ => false
+ }
+}
diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala
new file mode 100644
index 0000000000..00635a7ffa
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.examples.bagel
+
+import spark._
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.xml.{XML,NodeSeq}
+
+/**
+ * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles"
+ * files from there, which contains one line per wiki article in a tab-separated format
+ * (http://wiki.freebase.com/wiki/WEX/Documentation#articles).
+ */
+object WikipediaPageRank {
+ def main(args: Array[String]) {
+ if (args.length < 5) {
+ System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numPartitions> <host> <usePartitioner>")
+ System.exit(-1)
+ }
+
+ System.setProperty("spark.serializer", "spark.KryoSerializer")
+ System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+
+ val inputFile = args(0)
+ val threshold = args(1).toDouble
+ val numPartitions = args(2).toInt
+ val host = args(3)
+ val usePartitioner = args(4).toBoolean
+ val sc = new SparkContext(host, "WikipediaPageRank")
+
+ // Parse the Wikipedia page data into a graph
+ val input = sc.textFile(inputFile)
+
+ println("Counting vertices...")
+ val numVertices = input.count()
+ println("Done counting vertices.")
+
+ println("Parsing input file...")
+ var vertices = input.map(line => {
+ val fields = line.split("\t")
+ val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+ val links =
+ if (body == "\\N")
+ NodeSeq.Empty
+ else
+ try {
+ XML.loadString(body) \\ "link" \ "target"
+ } catch {
+ case e: org.xml.sax.SAXParseException =>
+ System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+ NodeSeq.Empty
+ }
+ val outEdges = links.map(link => new String(link.text)).toArray
+ val id = new String(title)
+ (id, new PRVertex(1.0 / numVertices, outEdges))
+ })
+ if (usePartitioner)
+ vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache
+ else
+ vertices = vertices.cache
+ println("Done parsing input file.")
+
+ // Do the computation
+ val epsilon = 0.01 / numVertices
+ val messages = sc.parallelize(Array[(String, PRMessage)]())
+ val utils = new PageRankUtils
+ val result =
+ Bagel.run(
+ sc, vertices, messages, combiner = new PRCombiner(),
+ numPartitions = numPartitions)(
+ utils.computeWithCombiner(numVertices, epsilon))
+
+ // Print the result
+ System.err.println("Articles with PageRank >= "+threshold+":")
+ val top =
+ (result
+ .filter { case (id, vertex) => vertex.value >= threshold }
+ .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) }
+ .collect.mkString)
+ println(top)
+ }
+}
diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala
new file mode 100644
index 0000000000..c416ddbc58
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.examples.bagel
+
+import spark._
+import serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.xml.{XML,NodeSeq}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
+
+object WikipediaPageRankStandalone {
+ def main(args: Array[String]) {
+ if (args.length < 5) {
+ System.err.println("Usage: WikipediaPageRankStandalone <inputFile> <threshold> <numIterations> <host> <usePartitioner>")
+ System.exit(-1)
+ }
+
+ System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer")
+
+ val inputFile = args(0)
+ val threshold = args(1).toDouble
+ val numIterations = args(2).toInt
+ val host = args(3)
+ val usePartitioner = args(4).toBoolean
+ val sc = new SparkContext(host, "WikipediaPageRankStandalone")
+
+ val input = sc.textFile(inputFile)
+ val partitioner = new HashPartitioner(sc.defaultParallelism)
+ val links =
+ if (usePartitioner)
+ input.map(parseArticle _).partitionBy(partitioner).cache()
+ else
+ input.map(parseArticle _).cache()
+ val n = links.count()
+ val defaultRank = 1.0 / n
+ val a = 0.15
+
+ // Do the computation
+ val startTime = System.currentTimeMillis
+ val ranks =
+ pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism)
+
+ // Print the result
+ System.err.println("Articles with PageRank >= "+threshold+":")
+ val top =
+ (ranks
+ .filter { case (id, rank) => rank >= threshold }
+ .map { case (id, rank) => "%s\t%s\n".format(id, rank) }
+ .collect().mkString)
+ println(top)
+
+ val time = (System.currentTimeMillis - startTime) / 1000.0
+ println("Completed %d iterations in %f seconds: %f seconds per iteration"
+ .format(numIterations, time, time / numIterations))
+ System.exit(0)
+ }
+
+ def parseArticle(line: String): (String, Array[String]) = {
+ val fields = line.split("\t")
+ val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+ val id = new String(title)
+ val links =
+ if (body == "\\N")
+ NodeSeq.Empty
+ else
+ try {
+ XML.loadString(body) \\ "link" \ "target"
+ } catch {
+ case e: org.xml.sax.SAXParseException =>
+ System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+ NodeSeq.Empty
+ }
+ val outEdges = links.map(link => new String(link.text)).toArray
+ (id, outEdges)
+ }
+
+ def pageRank(
+ links: RDD[(String, Array[String])],
+ numIterations: Int,
+ defaultRank: Double,
+ a: Double,
+ n: Long,
+ partitioner: Partitioner,
+ usePartitioner: Boolean,
+ numPartitions: Int
+ ): RDD[(String, Double)] = {
+ var ranks = links.mapValues { edges => defaultRank }
+ for (i <- 1 to numIterations) {
+ val contribs = links.groupWith(ranks).flatMap {
+ case (id, (linksWrapper, rankWrapper)) =>
+ if (linksWrapper.length > 0) {
+ if (rankWrapper.length > 0) {
+ linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size))
+ } else {
+ linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size))
+ }
+ } else {
+ Array[(String, Double)]()
+ }
+ }
+ ranks = (contribs.combineByKey((x: Double) => x,
+ (x: Double, y: Double) => x + y,
+ (x: Double, y: Double) => x + y,
+ partitioner)
+ .mapValues(sum => a/n + (1-a)*sum))
+ }
+ ranks
+ }
+}
+
+class WPRSerializer extends spark.serializer.Serializer {
+ def newInstance(): SerializerInstance = new WPRSerializerInstance()
+}
+
+class WPRSerializerInstance extends SerializerInstance {
+ def serialize[T](t: T): ByteBuffer = {
+ throw new UnsupportedOperationException()
+ }
+
+ def deserialize[T](bytes: ByteBuffer): T = {
+ throw new UnsupportedOperationException()
+ }
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ throw new UnsupportedOperationException()
+ }
+
+ def serializeStream(s: OutputStream): SerializationStream = {
+ new WPRSerializationStream(s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new WPRDeserializationStream(s)
+ }
+}
+
+class WPRSerializationStream(os: OutputStream) extends SerializationStream {
+ val dos = new DataOutputStream(os)
+
+ def writeObject[T](t: T): SerializationStream = t match {
+ case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
+ case links: Array[String] => {
+ dos.writeInt(0) // links
+ dos.writeUTF(id)
+ dos.writeInt(links.length)
+ for (link <- links) {
+ dos.writeUTF(link)
+ }
+ this
+ }
+ case rank: Double => {
+ dos.writeInt(1) // rank
+ dos.writeUTF(id)
+ dos.writeDouble(rank)
+ this
+ }
+ }
+ case (id: String, rank: Double) => {
+ dos.writeInt(2) // rank without wrapper
+ dos.writeUTF(id)
+ dos.writeDouble(rank)
+ this
+ }
+ }
+
+ def flush() { dos.flush() }
+ def close() { dos.close() }
+}
+
+class WPRDeserializationStream(is: InputStream) extends DeserializationStream {
+ val dis = new DataInputStream(is)
+
+ def readObject[T](): T = {
+ val typeId = dis.readInt()
+ typeId match {
+ case 0 => {
+ val id = dis.readUTF()
+ val numLinks = dis.readInt()
+ val links = new Array[String](numLinks)
+ for (i <- 0 until numLinks) {
+ val link = dis.readUTF()
+ links(i) = link
+ }
+ (id, ArrayBuffer(links)).asInstanceOf[T]
+ }
+ case 1 => {
+ val id = dis.readUTF()
+ val rank = dis.readDouble()
+ (id, ArrayBuffer(rank)).asInstanceOf[T]
+ }
+ case 2 => {
+ val id = dis.readUTF()
+ val rank = dis.readDouble()
+ (id, rank).asInstanceOf[T]
+ }
+ }
+ }
+
+ def close() { dis.close() }
+}