aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorRam Sriharsha <rsriharsha@hw11853.local>2015-05-22 09:59:44 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-22 09:59:51 -0700
commitd709d7cebd1cd5b27f3b6d15629c3e88367acae1 (patch)
tree99810811a5b839d7f0e9d9d1a87a2ee9e614428e /mllib/src/test
parent427dc04c1e9fa6c30dd899d28ad8261896b2f07e (diff)
downloadspark-d709d7cebd1cd5b27f3b6d15629c3e88367acae1.tar.gz
spark-d709d7cebd1cd5b27f3b6d15629c3e88367acae1.tar.bz2
spark-d709d7cebd1cd5b27f3b6d15629c3e88367acae1.zip
[SPARK-7404] [ML] Add RegressionEvaluator to spark.ml
Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #6344 from harsha2010/SPARK-7404 and squashes the following commits: 16b9d77 [Ram Sriharsha] consistent naming 7f100b6 [Ram Sriharsha] cleanup c46044d [Ram Sriharsha] Merge with Master + Code Review Fixes 188fa0a [Ram Sriharsha] Merge branch 'master' into SPARK-7404 f5b6a4c [Ram Sriharsha] cleanup doc 97beca5 [Ram Sriharsha] update test to use R packages 32dd310 [Ram Sriharsha] fix indentation f93b812 [Ram Sriharsha] fix test 1b6ebb3 [Ram Sriharsha] [SPARK-7404][ml] Add RegressionEvaluator to spark.ml (cherry picked from commit f490b3b4c706c92aa65d000b9d885f4d160a5f39) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala71
1 files changed, 71 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
new file mode 100644
index 0000000000..983f8b460b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
+
+ test("Regression Evaluator: default params") {
+ /**
+ * Here is the instruction describing how to export the test data into CSV format
+ * so we can validate the metrics compared with R's mmetric package.
+ *
+ * import org.apache.spark.mllib.util.LinearDataGenerator
+ * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3,
+ * Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1))
+ * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
+ * .saveAsTextFile("path")
+ */
+ val dataset = sqlContext.createDataFrame(
+ sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+ /**
+ * Using the following R code to load the data, train the model and evaluate metrics.
+ *
+ * > library("glmnet")
+ * > library("rminer")
+ * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+ * > label <- as.numeric(data$V1)
+ * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)
+ * > rmse <- mmetric(label, predict(model, features), metric='RMSE')
+ * > mae <- mmetric(label, predict(model, features), metric='MAE')
+ * > r2 <- mmetric(label, predict(model, features), metric='R2')
+ */
+ val trainer = new LinearRegression
+ val model = trainer.fit(dataset)
+ val predictions = model.transform(dataset)
+
+ // default = rmse
+ val evaluator = new RegressionEvaluator()
+ assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
+
+ // r2 score
+ evaluator.setMetricName("r2")
+ assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)
+
+ // mae
+ evaluator.setMetricName("mae")
+ assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
+ }
+}