aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Ram Sriharsha <rsriharsha@hw11853.local> 2015-07-30 23:02:11 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-07-30 23:02:11 -0700
commit 4e5919bfb47a58bcbda90ae01c1bed2128ded983 (patch)
tree ed723467f08be003bd57a97c6909ee48d1b8029c /mllib
parent 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 (diff)
download spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.tar.gz
spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.tar.bz2
spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.zip
[SPARK-7690] [ML] Multiclass classification Evaluator
Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision and weighted recall are supported as available metrics. Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits: 9bf4ec7 [Ram Sriharsha] fix indentation 3f09a85 [Ram Sriharsha] cleanup doc 16115ae [Ram Sriharsha] code review fixes 032d2a3 [Ram Sriharsha] fix test eec9865 [Ram Sriharsha] Fix Python Indentation 1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690 68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690 54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala 85
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala 28
2 files changed, 113 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
new file mode 100644
index 0000000000..44f779c190
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: Experimental ::
+ * Evaluator for multiclass classification that reads two input columns: prediction and label.
+ */
+@Experimental
+class MulticlassClassificationEvaluator (override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("mcEval"))
+
+  /**
+   * Name of the metric returned by `evaluate`: `"f1"` (default), `"precision"`, `"recall"`,
+   * `"weightedPrecision"` or `"weightedRecall"`.
+   * @group param
+   */
+  val metricName: Param[String] = {
+    val allowed = Array("f1", "precision", "recall", "weightedPrecision", "weightedRecall")
+    val description =
+      "metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)"
+    new Param(this, "metricName", description, ParamValidators.inArray(allowed))
+  }
+
+  setDefault(metricName -> "f1")
+
+  /** @group getParam */
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Returns the metric selected by `metricName`; `"f1"` maps to the weighted F-measure.
+   */
+  override def evaluate(dataset: DataFrame): Double = {
+    // Validate up front that both input columns are declared as DoubleType.
+    SchemaUtils.checkColumnType(dataset.schema, $(predictionCol), DoubleType)
+    SchemaUtils.checkColumnType(dataset.schema, $(labelCol), DoubleType)
+
+    val selected = dataset.select($(predictionCol), $(labelCol))
+    val predictionLabelPairs = selected.map { case Row(p: Double, l: Double) => (p, l) }
+    val metrics = new MulticlassMetrics(predictionLabelPairs)
+    $(metricName) match {
+      case "f1" => metrics.weightedFMeasure
+      case "precision" => metrics.precision
+      case "recall" => metrics.recall
+      case "weightedPrecision" => metrics.weightedPrecision
+      case "weightedRecall" => metrics.weightedRecall
+    }
+  }
+
+  override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000000..6d8412b0b3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+  test("params") { // hands a default-constructed evaluator to ParamsSuite.checkParams
+    val evaluator = new MulticlassClassificationEvaluator
+    ParamsSuite.checkParams(evaluator)
+  }
+}