aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-02-05 23:43:47 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-05 23:43:47 -0800
commitdc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f (patch)
tree745d33737eaddc95a0c55a814e84c7b96f9ecbcf /mllib/src/test/scala/org
parent6b88825a25a0a072c13bbcc57bbfdb102a3f133d (diff)
downloadspark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.tar.gz
spark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.tar.bz2
spark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.zip
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs] **UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion. Here is a list of changes which are public: * new output columns: rawPrediction, probabilities * The “score” column is now called “rawPrediction” * Classifiers now provide numClasses * Params.get and .set are now protected instead of private[ml]. * ParamMap now has a size method. * new classes: LinearRegression, LinearRegressionModel * LogisticRegression now has an intercept. ### Sketch of APIs (most of which are private[spark] for now) Abstract classes for learning algorithms (+ corresponding Model abstractions): * Classifier (+ ClassificationModel) * ProbabilisticClassifier (+ ProbabilisticClassificationModel) * Regressor (+ RegressionModel) * Predictor (+ PredictionModel) * *For all of these*: * There is no strongly typed training-time API. * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms. Concrete classes: learning algorithms * LinearRegression * LogisticRegression (updated to use new abstract classes) * Also, removed "score" in favor of "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031) Other updates: * params.scala: Changed Params.set/get to be protected instead of private[ml] * This was needed for the example of defining a class from outside of the MLlib namespace. * VectorUDT: Will later change from private[spark] to public. * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors. * Also, added equals() method.f * SPARK-4942 : ML Transformers should allow output cols to be turned on,off * Update validateAndTransformSchema * Update transform * (Updated examples, test suites according to other changes) New examples: * DeveloperApiExample.scala (example of defining algorithm from outside of the MLlib namespace) * Added Java version too Test Suites: * LinearRegressionSuite * LogisticRegressionSuite * + Java versions of above suites CC: mengxr etrain shivaram Author: Joseph K. Bradley <joseph@databricks.com> Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits: 405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types. 8316d5e [Joseph K. Bradley] fixes after rebasing on master fc62406 [Joseph K. Bradley] fixed test suites after last commit bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame) 9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it 216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value. 343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package 82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites 0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites c3c8da5 [Joseph K. Bradley] small cleanup 934f97b [Joseph K. Bradley] Fixed bugs from previous commit. 1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier. * This was to support output column “probabilityCol” in transform(). 4e2f711 [Joseph K. Bradley] rat fix bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite 8d13233 [Joseph K. Bradley] Added methods: * Classifier: batch predictRaw() * Predictor: train() without paramMap ProbabilisticClassificationModel.predictProbabilities() * Java versions of all above batch methods + others 1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0 adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap * added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites 58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap. 57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap. * remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml 54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly 0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup). 601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples. d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch 52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala86
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala65
2 files changed, 140 insertions, 11 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 33e40dc741..b3d1bfcfbe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,44 +20,108 @@ package org.apache.spark.ml.classification
import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
+ private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
- sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}
- test("logistic regression") {
+ test("logistic regression: default params") {
val lr = new LogisticRegression
+ assert(lr.getLabelCol == "label")
+ assert(lr.getFeaturesCol == "features")
+ assert(lr.getPredictionCol == "prediction")
+ assert(lr.getRawPredictionCol == "rawPrediction")
+ assert(lr.getProbabilityCol == "probability")
val model = lr.fit(dataset)
model.transform(dataset)
- .select("label", "prediction")
+ .select("label", "probability", "prediction", "rawPrediction")
.collect()
+ assert(model.getThreshold === 0.5)
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ assert(model.getRawPredictionCol == "rawPrediction")
+ assert(model.getProbabilityCol == "probability")
}
test("logistic regression with setters") {
+ // Set params, train, and check as many params as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select("label", "score", "prediction")
+ assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+ assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+ assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+ assert(model.getThreshold === 0.6)
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0)
+ val predAllZero = model.transform(dataset)
+ .select("prediction", "myProbability")
.collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predAllZero.forall(_ === 0),
+ s"With threshold=1.0, expected predictions to be all 0, but only" +
+ s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
+ // Call transform with params, and check that the params worked.
+ val predNotAllZero =
+ model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ .select("prediction", "myProb")
+ .collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predNotAllZero.exists(_ !== 0.0))
+
+ // Call fit() with new params, and check as many params as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ lr.probabilityCol -> "theProb")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+ assert(model2.getThreshold === 0.4)
+ assert(model2.getProbabilityCol == "theProb")
}
- test("logistic regression fit and transform with varargs") {
+ test("logistic regression: Predictor, Classifier methods") {
+ val sqlContext = this.sqlContext
val lr = new LogisticRegression
- val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
- model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select("label", "probability", "prediction")
- .collect()
+
+ val model = lr.fit(dataset)
+ assert(model.numClasses === 2)
+
+ val threshold = model.getThreshold
+ val results = model.transform(dataset)
+
+ // Compare rawPrediction with probability
+ results.select("rawPrediction", "probability").collect().map {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 2)
+ assert(prob.size === 2)
+ val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+ assert(prob(1) ~== probFromRaw1 relTol eps)
+ assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
+ }
+
+ // Compare prediction with probability
+ results.select("prediction", "probability").collect().map {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
new file mode 100644
index 0000000000..bbb44c3e2d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ dataset = sqlContext.createDataFrame(
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+ }
+
+ test("linear regression: default params") {
+ val lr = new LinearRegression
+ assert(lr.getLabelCol == "label")
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ // Check defaults
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ }
+
+ test("linear regression with setters") {
+ // Set params, train, and check as many as we can.
+ val lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ assert(model.fittingParamMap.get(lr.maxIter).get === 10)
+ assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
+
+ // Call fit() with new params, and check as many as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.getPredictionCol == "thePred")
+ }
+}