author     RoyGaoVLIS <roygao@zju.edu.cn>       2015-11-17 23:00:49 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-11-17 23:00:49 -0800
commit     67a5132c21bc8338adbae80b33b85e8fa0ddda34 (patch)
tree       35b65456ab1f452190cbab48ed38a541fa526b9c /mllib
parent     446738e51fcda50cf2dc44123ff6bf12a1611dc0 (diff)
[SPARK-7013][ML][TEST] Add unit test for spark.ml StandardScaler
Adds a unit test for spark.ml's StandardScaler, verifying its output against results computed with R. Author: RoyGaoVLIS <roygao@zju.edu.cn> Closes #6665 from RoyGao/7013.
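For reference, both MLlib's StandardScaler and R's scale() divide by the sample standard deviation (divisor n - 1), so the expected values in the suite can be reproduced by hand. A minimal, self-contained sketch, not part of the patch (the object name is illustrative):

object StandardizeFirstColumn {
  def main(args: Array[String]): Unit = {
    // First component of each input vector in the test data below.
    val col = Array(-2.0, 0.0, 1.7)
    val mean = col.sum / col.length                 // -0.1
    val variance = col.map(x => (x - mean) * (x - mean)).sum / (col.length - 1)
    val std = math.sqrt(variance)                   // ~1.852026
    // withStd only (the default): -1.079898..., 0.0, 0.917913...
    col.foreach(x => println(x / std))
    // withMean and withStd: -1.025903..., 0.053994..., 0.971908...
    col.foreach(x => println((x - mean) / std))
  }
}

Printing these matches the first components of resWithStd and resWithBoth in the suite.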
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala | 108
1 file changed, 108 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
new file mode 100644
index 0000000000..879a3ae875
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
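+ // Shared fixtures: the input vectors and the expected result for each
+ // StandardScaler configuration exercised below.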
+ @transient var data: Array[Vector] = _
+ @transient var resWithStd: Array[Vector] = _
+ @transient var resWithMean: Array[Vector] = _
+ @transient var resWithBoth: Array[Vector] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
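+ // The expected arrays below were generated outside Spark; per the PR
+ // description they were checked against R's output (R's scale() uses the
+ // sample standard deviation, divisor n - 1, as does MLlib).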
+ data = Array(
+ Vectors.dense(-2.0, 2.3, 0.0),
+ Vectors.dense(0.0, -5.1, 1.0),
+ Vectors.dense(1.7, -0.6, 3.3)
+ )
+ resWithMean = Array(
+ Vectors.dense(-1.9, 3.433333333333, -1.433333333333),
+ Vectors.dense(0.1, -3.966666666667, -0.433333333333),
+ Vectors.dense(1.8, 0.533333333333, 1.866666666667)
+ )
+ resWithStd = Array(
+ Vectors.dense(-1.079898494312, 0.616834091415, 0.0),
+ Vectors.dense(0.0, -1.367762550529, 0.590968109266),
+ Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579)
+ )
+ resWithBoth = Array(
+ Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497),
+ Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682),
+ Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631)
+ )
+ }
+
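+ // Compares the scaled output column with the expected column, element-wise,
+ // within an absolute tolerance of 1e-5.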
+ def assertResult(dataframe: DataFrame): Unit = {
+ dataframe.select("standardized_features", "expected").collect().foreach {
+ case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1 ~== vector2 absTol 1E-5,
+ "The vector value is not correct after standardization.")
+ }
+ }
+
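+ // StandardScaler defaults to withStd = true and withMean = false, so the
+ // expected output here is resWithStd.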
+ test("Standardization with default parameter") {
+ val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
+
+ val standardScaler0 = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("standardized_features")
+ .fit(df0)
+
+ assertResult(standardScaler0.transform(df0))
+ }
+
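+ // Covers the remaining flag combinations: centering and scaling together
+ // (resWithBoth), centering only (resWithMean), and the identity case where
+ // both flags are false (output equals input).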
+ test("Standardization with setter") {
+ val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
+ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
+ val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
+
+ val standardScaler1 = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("standardized_features")
+ .setWithMean(true)
+ .setWithStd(true)
+ .fit(df1)
+
+ val standardScaler2 = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("standardized_features")
+ .setWithMean(true)
+ .setWithStd(false)
+ .fit(df2)
+
+ val standardScaler3 = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("standardized_features")
+ .setWithMean(false)
+ .setWithStd(false)
+ .fit(df3)
+
+ assertResult(standardScaler1.transform(df1))
+ assertResult(standardScaler2.transform(df2))
+ assertResult(standardScaler3.transform(df3))
+ }
+}