authorTimothy Hunter <timhunter@databricks.com>2017-03-23 18:42:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2017-03-23 18:42:13 -0700
commitd27daa54bd341b29737a6352d9a1055151248ae7 (patch)
treeaa91b5865e37cb2b715a31acf371878a96bb4518
parent93581fbc18c01595918c565f6737aaa666116114 (diff)
downloadspark-d27daa54bd341b29737a6352d9a1055151248ae7.tar.gz
spark-d27daa54bd341b29737a6352d9a1055151248ae7.tar.bz2
spark-d27daa54bd341b29737a6352d9a1055151248ae7.zip
[SPARK-19636][ML] Feature parity for correlation statistics in MLlib
## What changes were proposed in this pull request?

This patch adds DataFrame-based support for the correlation statistics found in `org.apache.spark.mllib.stat.correlation.Statistics`, following the design doc discussed in the JIRA ticket.

The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage.

## How was this patch tested?

```
build/sbt "testOnly org.apache.spark.ml.stat.CorrelationSuite"
```

Author: Timothy Hunter <timhunter@databricks.com>

Closes #17108 from thunterdb/19636.
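For orientation, a minimal usage sketch of the API this patch adds, assuming a running SparkSession named `spark` (the data and column name are illustrative, not part of the patch):

```scala
import org.apache.spark.ml.linalg.{Matrix, Vectors}
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.sql.Row

val data = Seq(
  Vectors.dense(1.0, 0.0, -2.0),
  Vectors.dense(4.0, 5.0, 3.0),
  Vectors.dense(6.0, 7.0, 8.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

// Pearson is the default; the result is a single-row DataFrame whose one
// column, named "pearson(features)", holds the correlation Matrix.
val Row(pearson: Matrix) = Correlation.corr(df, "features").head
// Spearman must be requested explicitly.
val Row(spearman: Matrix) = Correlation.corr(df, "features", "spearman").head
```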
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala86
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala77
3 files changed, 171 insertions, 0 deletions
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2c..30edd00fb5 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -32,6 +32,10 @@ object TestingUtils {
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ // Special case for NaNs
+ if (x.isNaN && y.isNaN) {
+ return true
+ }
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
@@ -49,6 +53,10 @@ object TestingUtils {
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ // Special case for NaNs
+ if (x.isNaN && y.isNaN) {
+ return true
+ }
math.abs(x - y) < eps
}
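The NaN special case is needed because IEEE 754 comparisons treat NaN as unequal to everything, including itself; the correlation tests below expect NaN entries for a constant column, so without this guard the tolerance checks would always fail on those entries. A self-contained sketch of the behavior (the helper name is lowercased here for illustration):

```scala
// NaN is unequal to everything under IEEE 754, including itself, so a
// plain tolerance check can never succeed for two NaN values:
val nan = Double.NaN
assert(nan != nan)
assert(!(math.abs(nan - nan) < 1e-4)) // math.abs(NaN - NaN) is NaN

// With the guard added above, two NaNs compare as equal:
def absoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
  if (x.isNaN && y.isNaN) return true
  math.abs(x - y) < eps
}
assert(absoluteErrorComparison(nan, nan, 1e-4))
```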
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
new file mode 100644
index 0000000000..a7243ccbf2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
+import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * API for correlation functions in MLlib, compatible with DataFrames and Datasets.
+ *
+ * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
+ * to spark.ml's Vector types.
+ */
+@Since("2.2.0")
+@Experimental
+object Correlation {
+
+ /**
+ * :: Experimental ::
+ * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
+ * Methods currently supported: `pearson` (default), `spearman`.
+ *
+ * @param dataset A dataset or a dataframe
+ * @param column The name of the column of vectors for which the correlation coefficient needs
+ * to be computed. This must be a column of the dataset, and it must contain
+ * Vector objects.
+ * @param method String specifying the method to use for computing correlation.
+ * Supported: `pearson` (default), `spearman`
+ * @return A dataframe that contains the correlation matrix of the column of vectors. This
+ * dataframe contains a single row and a single column of name
+ * '$METHODNAME($COLUMN)'.
+ * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if
+ * the content of this column is not of type Vector.
+ *
+ * Here is how to access the correlation coefficient:
+ * {{{
+ * val data: Dataset[Vector] = ...
+ * val Row(coeff: Matrix) = Correlation.corr(data, "value").head
+ * // coeff now contains the Pearson correlation matrix.
+ * }}}
+ *
+ * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
+ * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
+ * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
+ * to avoid recomputing the common lineage.
+ */
+ @Since("2.2.0")
+ def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
+ val rdd = dataset.select(column).rdd.map {
+ case Row(v: Vector) => OldVectors.fromML(v)
+ }
+ val oldM = OldStatistics.corr(rdd, method)
+ val name = s"$method($column)"
+ val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
+ dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
+ }
+
+ /**
+ * Compute the Pearson correlation matrix for the input Dataset of Vectors.
+ */
+ @Since("2.2.0")
+ def corr(dataset: Dataset[_], column: String): DataFrame = {
+ corr(dataset, column, "pearson")
+ }
+}
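Per the Spearman note in the scaladoc above, the rank computation traverses the input lineage more than once, so it pays to cache the input first. A minimal sketch, assuming a DataFrame `df` with a Vector column named `features` (both illustrative):

```scala
import org.apache.spark.ml.stat.Correlation

// Spearman ranks each column by sorting, so the input lineage is traversed
// several times; caching avoids recomputing it on every pass.
df.cache()
val spearmanCorr = Correlation.corr(df, "features", "spearman")
df.unpersist()
```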
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
new file mode 100644
index 0000000000..7d935e651f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+
+ val xData = Array(1.0, 0.0, -2.0)
+ val yData = Array(4.0, 5.0, 3.0)
+ val zeros = new Array[Double](3)
+ val data = Seq(
+ Vectors.dense(1.0, 0.0, 0.0, -2.0),
+ Vectors.dense(4.0, 5.0, 0.0, 3.0),
+ Vectors.dense(6.0, 7.0, 0.0, 8.0),
+ Vectors.dense(9.0, 0.0, 0.0, 1.0)
+ )
+
+ private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")
+
+ private def extract(df: DataFrame): BDM[Double] = {
+ val Array(Row(mat: Matrix)) = df.collect()
+ mat.asBreeze.toDenseMatrix
+ }
+
+
+ test("corr(X) default, pearson") {
+ val defaultMat = Correlation.corr(X, "features")
+ val pearsonMat = Correlation.corr(X, "features", "pearson")
+ // scalastyle:off
+ val expected = Matrices.fromBreeze(BDM(
+ (1.00000000, 0.05564149, Double.NaN, 0.4004714),
+ (0.05564149, 1.00000000, Double.NaN, 0.9135959),
+ (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
+ (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
+ // scalastyle:on
+
+ assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4)
+ assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4)
+ }
+
+ test("corr(X) spearman") {
+ val spearmanMat = Correlation.corr(X, "features", "spearman")
+ // scalastyle:off
+ val expected = Matrices.fromBreeze(BDM(
+ (1.0000000, 0.1054093, Double.NaN, 0.4000000),
+ (0.1054093, 1.0000000, Double.NaN, 0.9486833),
+ (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
+ (0.4000000, 0.9486833, Double.NaN, 1.0000000)))
+ // scalastyle:on
+ assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4)
+ }
+
+}
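As a sanity check on the hard-coded expected values, the (0, 1) Pearson entry can be reproduced by hand from columns 0 and 1 of `data`; a self-contained sketch, not part of the suite:

```scala
// Hand-check of the expected Pearson coefficient between columns 0 and 1
// of `data` above: r = sum(dx*dy) / sqrt(sum(dx^2) * sum(dy^2)).
val x = Seq(1.0, 4.0, 6.0, 9.0) // column 0
val y = Seq(0.0, 5.0, 7.0, 0.0) // column 1
val (mx, my) = (x.sum / x.size, y.sum / y.size)
val dx = x.map(_ - mx)
val dy = y.map(_ - my)
val r = dx.zip(dy).map { case (a, b) => a * b }.sum /
  math.sqrt(dx.map(a => a * a).sum * dy.map(b => b * b).sum)
// r ≈ 0.05564, matching the (0, 1) entry of the expected Pearson matrix.
```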