author    Xiangrui Meng <meng@databricks.com>  2015-06-19 09:46:51 -0700
committer Xiangrui Meng <meng@databricks.com>  2015-06-19 10:05:07 -0700
commit    1f2dafb77f9af52602885cd5767032a20b486b98 (patch)
tree      ae4d9b5bf3f102325f8d763fc4f3ba437726fb05 /mllib/src/test
parent    164b9d32e764b2a67b372a3d685b57c4bbeccbfa (diff)
[SPARK-8151] [MLLIB] pipeline components should correctly implement copy
Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6622 from mengxr/SPARK-8087 and squashes the following commits:

0e4c8c4 [Xiangrui Meng] fix merge issues
26fc1f0 [Xiangrui Meng] address comments
e607a04 [Xiangrui Meng] merge master
b85b57e [Xiangrui Meng] fix examples/compile
d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy
84ec278 [Xiangrui Meng] remove setter checks due to generics
2cf2ed0 [Xiangrui Meng] snapshot
291814f [Xiangrui Meng] OneVsRest.copy
1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages

(cherry picked from commit 43c7ec6384e51105dedf3a53354b6a3732cc27b2)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
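For context on why the copy contract matters, here is a minimal sketch of the pattern this patch standardizes, using a hypothetical MyTransformer that is not part of this commit. Without a defaultCopy-based override, params supplied as extras (for example through PipelineModel.transform) are silently dropped, and the return-type check added to ParamsSuite.checkParams below fails:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical pipeline component, for illustration only.
class MyTransformer(override val uid: String) extends Transformer {

  val numBuckets = new IntParam(this, "numBuckets", "number of buckets")
  setDefault(numBuckets -> 2)

  override def transform(dataset: DataFrame): DataFrame = dataset

  override def transformSchema(schema: StructType): StructType = schema

  // The pattern this patch standardizes: delegate to defaultCopy, which
  // re-instantiates the class via its uid constructor and copies params
  // plus extras, and declare the concrete return type so the reflection
  // check in ParamsSuite.checkParams passes.
  override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)
}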
Diffstat (limited to 'mllib/src/test')
mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 5
mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 10
mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 12
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 11
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 9
mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala | 30
mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 10
mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 28
mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 3
mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala | 8
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 7
mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala | 12
mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 5
mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 7
mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 8
mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 22
mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 4
mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala | 6
mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 4
24 files changed, 207 insertions(+), 19 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index ff5929235a..3ae09d39ef 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -102,4 +102,9 @@ public class JavaTestParams extends JavaParams {
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
+
+ @Override
+ public JavaTestParams copy(ParamMap extra) {
+ return defaultCopy(extra);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 05bf58e63a..a3c4f528da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -22,6 +22,7 @@ import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
@@ -81,4 +82,13 @@ class PipelineSuite extends SparkFunSuite {
pipeline.fit(dataset)
}
}
+
+ test("PipelineModel.copy") {
+ val hashingTF = new HashingTF()
+ .setNumFeatures(100)
+ val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
+ val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
+ require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+ "copy should handle extra stage params")
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index ae40b0b8ff..73b4805c4c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
- DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
}
+ test("params") {
+ ParamsSuite.checkParams(new DecisionTreeClassifier)
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ ParamsSuite.checkParams(model)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1302da3c37..82c345491b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
+ test("params") {
+ ParamsSuite.checkParams(new GBTClassifier)
+ val model = new GBTClassificationModel("gbtc",
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(1.0))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a755cac3ea..5a6265ea99 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -18,8 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new LogisticRegression)
+ val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
test("logistic regression: default params") {
val lr = new LogisticRegression
assert(lr.getLabelCol === "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 1d04ccb509..75cf5bd4ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,15 +19,18 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(rdd)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneVsRest)
+ val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
+ val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
+ ParamsSuite.checkParams(model)
+ }
+
test("one-vs-rest: default params") {
val numClasses = 3
val ova = new OneVsRest()
@@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
+
+ test("OneVsRest.copy and OneVsRestModel.copy") {
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+
+ val ovr = new OneVsRest()
+ withClue("copy with classifier unset should work") {
+ ovr.copy(ParamMap(lr.maxIter -> 10))
+ }
+ ovr.setClassifier(lr)
+ val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
+ require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
+ require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
+ "copy should handle extra classifier params")
+
+ val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+ ovrModel.models.foreach { case m: LogisticRegressionModel =>
+ require(m.getThreshold === 0.1, "copy should handle extra model params")
+ }
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index eee9355a67..1b6b69c7dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestClassifier]].
*/
@@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
}
+ test("params") {
+ ParamsSuite.checkParams(new RandomForestClassifier)
+ val model = new RandomForestClassificationModel("rfc",
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000000..def869fe66
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new BinaryClassificationEvaluator)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 36a1ac6b79..aa722da323 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -18,12 +18,17 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new RegressionEvaluator)
+ }
+
test("Regression Evaluator: default params") {
/**
* Here is the instruction describing how to export the test data into CSV format
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 7953bd0417..2086043983 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
}
+ test("params") {
+ ParamsSuite.checkParams(new Binarizer)
+ }
+
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 507a8a7db2..ec85e0d151 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Bucketizer)
+ }
+
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 7b2d70e644..4157b84b29 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
- val hashingTF = new HashingTF
- ParamsSuite.checkParams(hashingTF, 3)
+ ParamsSuite.checkParams(new HashingTF)
}
test("hashingTF") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index d83772e8be..08f80af034 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new IDF)
+ val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
+ ParamsSuite.checkParams(model)
+ }
+
test("compute IDF with default parameter") {
val numOfFeatures = 4
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 2e5036a844..65846a846b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
@@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneHotEncoder)
+ }
+
test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index feca866cd7..29eebd8960 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
@@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new PolynomialExpansion)
+ }
+
test("Polynomial expansion with default parameter") {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5f557e16e5..99f82bea42 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new StringIndexer)
+ val model = new StringIndexerModel("indexer", Array("a", "b"))
+ ParamsSuite.checkParams(model)
+ }
+
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index ac279cb321..e5fd21c3f6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
+class TokenizerSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new Tokenizer)
+ }
+}
+
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
+ test("params") {
+ ParamsSuite.checkParams(new RegexTokenizer)
+ }
+
test("RegexTokenizer") {
val tokenizer0 = new RegexTokenizer()
.setGaps(false)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 489abb5af7..bb4d5b983e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new VectorAssembler)
+ }
+
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 06affc7305..8c85c96d5c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
private def getIndexer: VectorIndexer =
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
+ test("params") {
+ ParamsSuite.checkParams(new VectorIndexer)
+ val model = new VectorIndexerModel("indexer", 1, Map.empty)
+ ParamsSuite.checkParams(model)
+ }
+
test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 94ebc3aebf..aa6ce533fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -18,13 +18,21 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Word2Vec)
+ val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 96094d7a09..050d4170ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
object ParamsSuite extends SparkFunSuite {
/**
- * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
- * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
- * the param method name.
+ * Checks common requirements for [[Params.params]]:
+ * - params are ordered by names
+ * - param parent has the same UID as the object's UID
+ * - param name is the same as the param method name
+ * - obj.copy should return the same type as the obj
*/
- def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+ def checkParams(obj: Params): Unit = {
+ val clazz = obj.getClass
+
val params = obj.params
- require(params.length === expectedNumParams,
- s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
val paramNames = params.map(_.name)
- require(paramNames === paramNames.sorted)
+ require(paramNames === paramNames.sorted, "params must be ordered by names")
params.foreach { p =>
assert(p.parent === obj.uid)
assert(obj.getParam(p.name) === p)
+ // TODO: Check that setters return self, which needs special handling for generic types.
}
+
+ val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
+ val copyReturnType = copyMethod.getReturnType
+ require(copyReturnType === obj.getClass,
+ s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}
}
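To summarize the new contract: checkParams no longer pins an exact param count, and the reflection check compares the declared return type of copy with the object's concrete class. A usage sketch against the TestParams fixture updated in this same patch (the explicit uid string is arbitrary):

import org.apache.spark.ml.param.{ParamsSuite, TestParams}

val obj = new TestParams("my_test_params")
// Verifies: params are sorted by name, each param's parent matches obj.uid,
// getParam round-trips by name, and copy's return type is TestParams itself
// (satisfied here because TestParams.copy now delegates to defaultCopy).
ParamsSuite.checkParams(obj)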
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index a9e78366ad..2759248344 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
require(isDefined(inputCol))
}
- override def copy(extra: ParamMap): TestParams = {
- super.copy(extra).asInstanceOf[TestParams]
- }
+ override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index eb5408d3fe..b3af81a3c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.param.shared
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.param.{ParamMap, Params}
class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
- class Obj(override val uid: String) extends Params with HasOutputCol
+ class Obj(override val uid: String) extends Params with HasOutputCol {
+ override def copy(extra: ParamMap): Obj = defaultCopy(extra)
+ }
val obj = new Obj("obj")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 6fef0b6205..dcfe9f7e27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -96,6 +96,8 @@ object CrossValidatorSuite {
override def transformSchema(schema: StructType): StructType = {
throw new UnsupportedOperationException
}
+
+ override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
}
class MyEvaluator extends Evaluator {
@@ -105,5 +107,7 @@ object CrossValidatorSuite {
}
override val uid: String = "eval"
+
+ override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
}
}