path: root/mllib/src/test
author    Yuhao Yang <hhbyyh@gmail.com>  2016-03-16 17:31:55 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-03-16 17:31:55 -0700
commit    92b70576eabf8ff94ac476e2b3c66f8b3d28e79e (patch)
tree      b15286aade54722a14fa9325e965803316a531f7 /mllib/src/test
parent    d4d84936fb82bee91f4b04608de9f75c293ccc9e (diff)
[SPARK-13761][ML] Deprecate validateParams
[SPARK-13761][ML] Deprecate validateParams

## What changes were proposed in this pull request?

Deprecate the validateParams() method here:
https://github.com/apache/spark/blob/035d3acdf3c1be5b309a861d5c5beb803b946b5e/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L553

Move all functionality in overridden methods to transformSchema(). Check docs to make sure they indicate that complex Param interaction checks should be done in transformSchema.

## How was this patch tested?

Unit tests.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #11620 from hhbyyh/depreValid.
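The test changes below all follow the same pattern: instead of calling the deprecated validateParams(), they build a dummy schema and call transformSchema(), which now hosts the same checks. As a rough illustration of what this looks like on the implementation side, here is a minimal sketch of a hypothetical RangeScaler transformer (not part of Spark or this patch) that performs its min/max interaction check in transformSchema(), written against the DataFrame-era ml API these suites use:

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical transformer sketching the new pattern: the min/max
// interaction check runs in transformSchema(), which Pipelines invoke
// during both fit() and transform(), instead of in validateParams().
class RangeScaler(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("rangeScaler"))

  val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output range")
  val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output range")
  setDefault(min -> 0.0, max -> 1.0)

  def setMin(value: Double): this.type = set(min, value)
  def setMax(value: Double): this.type = set(max, value)

  override def transformSchema(schema: StructType): StructType = {
    // Complex Param interaction check, formerly the job of validateParams().
    // require() throws IllegalArgumentException, matching the intercepts below.
    require($(min) < $(max),
      s"RangeScaler requires min (${$(min)}) < max (${$(max)})")
    schema
  }

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema)  // triggers the check shown above
    dataset  // scaling logic omitted; only the validation pattern matters here
  }

  override def copy(extra: ParamMap): RangeScaler = defaultCopy(extra)
}
```

A suite can then trigger the check with just a schema, mirroring the updated tests: intercept[IllegalArgumentException] { new RangeScaler().setMin(10.0).setMax(0.0).transformSchema(df.schema) }.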
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala       | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala |  8
3 files changed, 17 insertions, 13 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65eac..dd3f4c6e53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       new LDA().setTopicConcentration(-1.1)
     }
 
-    // validateParams()
-    lda.validateParams()
+    val dummyDF = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+    // validate parameters
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(1.1)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
     withClue("LDA docConcentration validity check failed for bad array length") {
       intercept[IllegalArgumentException] {
-        lda.validateParams()
+        lda.transformSchema(dummyDF.schema)
       }
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc07b6..87206c777e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") {
+ val dummyDF = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(10).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(0).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5df3..6bb4678dc5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
-    val slicer = new VectorSlicer
+    val slicer = new VectorSlicer().setInputCol("feature")
     ParamsSuite.checkParams(slicer)
     assert(slicer.getIndices.length === 0)
     assert(slicer.getNames.length === 0)
     withClue("VectorSlicer should not have any features selected by default") {
       intercept[IllegalArgumentException] {
-        slicer.validateParams()
+        slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
       }
     }
   }