diff options
Diffstat (limited to 'examples/src/main/scala')
3 files changed, 447 insertions, 0 deletions
diff --git a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala new file mode 100644 index 0000000000..c23ee9895f --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +import com.esotericsoftware.kryo._ + +class PageRankUtils extends Serializable { + def computeWithCombiner(numVertices: Long, epsilon: Double)( + self: PRVertex, messageSum: Option[Double], superstep: Int + ): (PRVertex, Array[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = superstep >= 10 + + val outbox: Array[PRMessage] = + if (!terminate) + self.outEdges.map(targetId => + new PRMessage(targetId, newValue / self.outEdges.size)) + else + Array[PRMessage]() + + (new PRVertex(newValue, self.outEdges, !terminate), outbox) + } + + def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = + computeWithCombiner(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +class PRCombiner extends Combiner[PRMessage, Double] with Serializable { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b +} + +class PRVertex() extends Vertex with Serializable { + var value: Double = _ + var outEdges: Array[String] = _ + var active: Boolean = _ + + def this(value: Double, outEdges: Array[String], active: Boolean = true) { + this() + this.value = value + this.outEdges = outEdges + this.active = active + } + + override def toString(): String = { + "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) + } +} + +class PRMessage() extends Message[String] with Serializable { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + } +} + +class CustomPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val hash = key match { + case k: Long => (k & 0x00000000FFFFFFFFL).toInt + case _ => key.hashCode + } + + val mod = key.hashCode % partitions + if (mod < 0) mod + partitions else mod + } + + override def equals(other: Any): Boolean = other match { + case c: CustomPartitioner => + c.numPartitions == numPartitions + case _ => false + } +} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala new file mode 100644 index 0000000000..00635a7ffa --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +/** + * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" + * files from there, which contains one line per wiki article in a tab-separated format + * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). + */ +object WikipediaPageRank { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numPartitions> <host> <usePartitioner>") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + + val inputFile = args(0) + val threshold = args(1).toDouble + val numPartitions = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRank") + + // Parse the Wikipedia page data into a graph + val input = sc.textFile(inputFile) + + println("Counting vertices...") + val numVertices = input.count() + println("Done counting vertices.") + + println("Parsing input file...") + var vertices = input.map(line => { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + val id = new String(title) + (id, new PRVertex(1.0 / numVertices, outEdges)) + }) + if (usePartitioner) + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + else + vertices = vertices.cache + println("Done parsing input file.") + + // Do the computation + val epsilon = 0.01 / numVertices + val messages = sc.parallelize(Array[(String, PRMessage)]()) + val utils = new PageRankUtils + val result = + Bagel.run( + sc, vertices, messages, combiner = new PRCombiner(), + numPartitions = numPartitions)( + utils.computeWithCombiner(numVertices, epsilon)) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (result + .filter { case (id, vertex) => vertex.value >= threshold } + .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } + .collect.mkString) + println(top) + } +} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala new file mode 100644 index 0000000000..c416ddbc58 --- /dev/null +++ b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.examples.bagel + +import spark._ +import serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer + +object WikipediaPageRankStandalone { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRankStandalone <inputFile> <threshold> <numIterations> <host> <usePartitioner>") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") + + val inputFile = args(0) + val threshold = args(1).toDouble + val numIterations = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRankStandalone") + + val input = sc.textFile(inputFile) + val partitioner = new HashPartitioner(sc.defaultParallelism) + val links = + if (usePartitioner) + input.map(parseArticle _).partitionBy(partitioner).cache() + else + input.map(parseArticle _).cache() + val n = links.count() + val defaultRank = 1.0 / n + val a = 0.15 + + // Do the computation + val startTime = System.currentTimeMillis + val ranks = + pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (ranks + .filter { case (id, rank) => rank >= threshold } + .map { case (id, rank) => "%s\t%s\n".format(id, rank) } + .collect().mkString) + println(top) + + val time = (System.currentTimeMillis - startTime) / 1000.0 + println("Completed %d iterations in %f seconds: %f seconds per iteration" + .format(numIterations, time, time / numIterations)) + System.exit(0) + } + + def parseArticle(line: String): (String, Array[String]) = { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val id = new String(title) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + (id, outEdges) + } + + def pageRank( + links: RDD[(String, Array[String])], + numIterations: Int, + defaultRank: Double, + a: Double, + n: Long, + partitioner: Partitioner, + usePartitioner: Boolean, + numPartitions: Int + ): RDD[(String, Double)] = { + var ranks = links.mapValues { edges => defaultRank } + for (i <- 1 to numIterations) { + val contribs = links.groupWith(ranks).flatMap { + case (id, (linksWrapper, rankWrapper)) => + if (linksWrapper.length > 0) { + if (rankWrapper.length > 0) { + linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + } else { + linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + } + } else { + Array[(String, Double)]() + } + } + ranks = (contribs.combineByKey((x: Double) => x, + (x: Double, y: Double) => x + y, + (x: Double, y: Double) => x + y, + partitioner) + .mapValues(sum => a/n + (1-a)*sum)) + } + ranks + } +} + +class WPRSerializer extends spark.serializer.Serializer { + def newInstance(): SerializerInstance = new WPRSerializerInstance() +} + +class WPRSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer): T = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + + def serializeStream(s: OutputStream): SerializationStream = { + new WPRSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new WPRDeserializationStream(s) + } +} + +class WPRSerializationStream(os: OutputStream) extends SerializationStream { + val dos = new DataOutputStream(os) + + def writeObject[T](t: T): SerializationStream = t match { + case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { + case links: Array[String] => { + dos.writeInt(0) // links + dos.writeUTF(id) + dos.writeInt(links.length) + for (link <- links) { + dos.writeUTF(link) + } + this + } + case rank: Double => { + dos.writeInt(1) // rank + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + case (id: String, rank: Double) => { + dos.writeInt(2) // rank without wrapper + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + + def flush() { dos.flush() } + def close() { dos.close() } +} + +class WPRDeserializationStream(is: InputStream) extends DeserializationStream { + val dis = new DataInputStream(is) + + def readObject[T](): T = { + val typeId = dis.readInt() + typeId match { + case 0 => { + val id = dis.readUTF() + val numLinks = dis.readInt() + val links = new Array[String](numLinks) + for (i <- 0 until numLinks) { + val link = dis.readUTF() + links(i) = link + } + (id, ArrayBuffer(links)).asInstanceOf[T] + } + case 1 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, ArrayBuffer(rank)).asInstanceOf[T] + } + case 2 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, rank).asInstanceOf[T] + } + } + } + + def close() { dis.close() } +} |