diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-04 13:30:17 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-04 13:30:17 -0800 |
commit | ba5f81859d6ba37a228a1c43d26c47e64c0382cd (patch) | |
tree | 375e9042ebf42c3a915186e1e21de6c650436b83 /mllib/src/test/scala/org/apache | |
parent | 0171b71e9511cef512e96a759e407207037f3c49 (diff) | |
download | spark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.tar.gz spark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.tar.bz2 spark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.zip |
[SPARK-11259][ML] Params.validateParams() should be called automatically
See JIRA: https://issues.apache.org/jira/browse/SPARK-11259
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9224 from yanboliang/spark-11259.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 8c86767456..f3321fb5a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite -import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("pipeline validateParams") { + val df = sqlContext.createDataFrame( + Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "features", "label") + + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("features_scaled") + .setMin(10) + .setMax(0) + val pipeline = new Pipeline().setStages(Array(scaler)) + pipeline.fit(df) + } + } } |