aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-04 13:30:17 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-04 13:30:17 -0800
commitba5f81859d6ba37a228a1c43d26c47e64c0382cd (patch)
tree375e9042ebf42c3a915186e1e21de6c650436b83 /mllib/src/test/scala/org/apache
parent0171b71e9511cef512e96a759e407207037f3c49 (diff)
downloadspark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.tar.gz
spark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.tar.bz2
spark-ba5f81859d6ba37a228a1c43d26c47e64c0382cd.zip
[SPARK-11259][ML] Params.validateParams() should be called automatically
See JIRA: https://issues.apache.org/jira/browse/SPARK-11259 Author: Yanbo Liang <ybliang8@gmail.com> Closes #9224 from yanboliang/spark-11259.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala23
1 file changed, 22 insertions, 1 deletion
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 8c86767456..f3321fb5a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
}
+
+ test("pipeline validateParams") {
+ val df = sqlContext.createDataFrame(
+ Seq(
+ (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+ (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+ (4, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+ ).toDF("id", "features", "label")
+
+ intercept[IllegalArgumentException] {
+ val scaler = new MinMaxScaler()
+ .setInputCol("features")
+ .setOutputCol("features_scaled")
+ .setMin(10)
+ .setMax(0)
+ val pipeline = new Pipeline().setStages(Array(scaler))
+ pipeline.fit(df)
+ }
+ }
}