aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2014-01-06 03:05:52 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2014-01-06 03:05:52 -0800
commit3b4c4c7f4d0d6e45a1acb0baf0d9416a8997b686 (patch)
tree98b744beacdb06925eba8413288c473b13f49174 /mllib
parentd0fd3b9ad238294346eb3465c489eabd41fb2380 (diff)
parenta2e7e0497484554f86bd71e93705eb0422b1512b (diff)
downloadspark-3b4c4c7f4d0d6e45a1acb0baf0d9416a8997b686.tar.gz
spark-3b4c4c7f4d0d6e45a1acb0baf0d9416a8997b686.tar.bz2
spark-3b4c4c7f4d0d6e45a1acb0baf0d9416a8997b686.zip
Merge remote-tracking branch 'apache/master' into project-refactor
Conflicts: examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala119
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala13
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala108
3 files changed, 233 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
new file mode 100644
index 0000000000..524300d6ae
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification
+
+import scala.collection.mutable
+
+import org.jblas.DoubleMatrix
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model for Naive Bayes Classifiers.
+ *
+ * @param pi Log of class priors, whose dimension is C.
+ * @param theta Log of class conditional probabilities, whose dimension is CXD.
+ */
+class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
+ extends ClassificationModel with Serializable {
+
+ // Create a column vector that can be used for predictions
+ private val _pi = new DoubleMatrix(pi.length, 1, pi: _*)
+ private val _theta = new DoubleMatrix(theta)
+
+ def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict)
+
+ def predict(testData: Array[Double]): Double = {
+ val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*)
+ val result = _pi.add(_theta.mmul(dataMatrix))
+ result.argmax()
+ }
+}
+
+/**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * @param lambda The smooth parameter
+ */
+class NaiveBayes private (val lambda: Double = 1.0)
+ extends Serializable with Logging {
+
+ /**
+ * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
+ *
+ * @param data RDD of (label, array of features) pairs.
+ */
+ def run(data: RDD[LabeledPoint]) = {
+ // Aggregates all sample points to driver side to get sample count and summed feature vector
+ // for each label. The shape of `zeroCombiner` & `aggregated` is:
+ //
+ // label: Int -> (count: Int, featuresSum: DoubleMatrix)
+ val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)]
+ val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) =>
+ point match {
+ case LabeledPoint(label, features) =>
+ val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1)))
+ val fs = new DoubleMatrix(features.length, 1, features: _*)
+ combiner += label.toInt -> (count + 1, featuresSum.addi(fs))
+ }
+ }, { (lhs, rhs) =>
+ for ((label, (c, fs)) <- rhs) {
+ val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1)))
+ lhs(label) = (count + c, featuresSum.addi(fs))
+ }
+ lhs
+ })
+
+ // Kinds of label
+ val C = aggregated.size
+ // Total sample count
+ val N = aggregated.values.map(_._1).sum
+
+ val pi = new Array[Double](C)
+ val theta = new Array[Array[Double]](C)
+ val piLogDenom = math.log(N + C * lambda)
+
+ for ((label, (count, fs)) <- aggregated) {
+ val thetaLogDenom = math.log(fs.sum() + fs.length * lambda)
+ pi(label) = math.log(count + lambda) - piLogDenom
+ theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom)
+ }
+
+ new NaiveBayesModel(pi, theta)
+ }
+}
+
+object NaiveBayes {
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector. it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ * @param lambda The smooth parameter
+ */
+ def train(input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
+ new NaiveBayes(lambda).run(input)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 36853acab5..8b27ecf82c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -578,14 +578,13 @@ object ALS {
val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
val alpha = if (args.length >= 8) args(7).toDouble else 1
val blocks = if (args.length == 9) args(8).toInt else -1
-
- System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
- System.setProperty("spark.kryo.referenceTracking", "false")
- System.setProperty("spark.kryoserializer.buffer.mb", "8")
- System.setProperty("spark.locality.wait", "10000")
-
val sc = new SparkContext(master, "ALS")
+ sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ sc.conf.set("spark.kryo.referenceTracking", "false")
+ sc.conf.set("spark.kryoserializer.buffer.mb", "8")
+ sc.conf.set("spark.locality.wait", "10000")
+
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
new file mode 100644
index 0000000000..b615f76e66
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.SparkContext
+
+object NaiveBayesSuite {
+
+ private def calcLabel(p: Double, pi: Array[Double]): Int = {
+ var sum = 0.0
+ for (j <- 0 until pi.length) {
+ sum += pi(j)
+ if (p < sum) return j
+ }
+ -1
+ }
+
+ // Generate input of the form Y = (theta * x).argmax()
+ def generateNaiveBayesInput(
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ val D = theta(0).length
+ val rnd = new Random(seed)
+
+ val _pi = pi.map(math.pow(math.E, _))
+ val _theta = theta.map(row => row.map(math.pow(math.E, _)))
+
+ for (i <- 0 until nPoints) yield {
+ val y = calcLabel(rnd.nextDouble(), _pi)
+ val xi = Array.tabulate[Double](D) { j =>
+ if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
+ }
+
+ LabeledPoint(y, xi)
+ }
+ }
+}
+
+class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOfPredictions = predictions.zip(input).count {
+ case (prediction, expected) =>
+ prediction != expected.label
+ }
+ // At least 80% of the predictions should be on.
+ assert(numOfPredictions < input.length / 5)
+ }
+
+ test("Naive Bayes") {
+ val nPoints = 10000
+
+ val pi = Array(0.5, 0.3, 0.2).map(math.log)
+ val theta = Array(
+ Array(0.91, 0.03, 0.03, 0.03), // label 0
+ Array(0.03, 0.91, 0.03, 0.03), // label 1
+ Array(0.03, 0.03, 0.91, 0.03) // label 2
+ ).map(_.map(math.log))
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val model = NaiveBayes.train(testRDD)
+
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
+}