Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala       | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala |  8
3 files changed, 17 insertions(+), 13 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65eac..dd3f4c6e53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
new LDA().setTopicConcentration(-1.1)
}
- // validateParams()
- lda.validateParams()
+ val dummyDF = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+ // validate parameters
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(1.1)
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
withClue("LDA docConcentration validity check failed for bad array length") {
intercept[IllegalArgumentException] {
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
}
}
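
Note on the pattern above: the removed validateParams() calls are replaced by transformSchema(), which validates the params against a schema alone, with no data or running SparkContext required. A minimal standalone sketch of the same idea follows; the object name SchemaCheckSketch and the setK(2) setup are illustrative and not part of the patch, and it builds the schema directly instead of going through a dummy DataFrame as the test does:

import org.apache.spark.ml.clustering.LDA
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{StructField, StructType}

object SchemaCheckSketch {
  def main(args: Array[String]): Unit = {
    // transformSchema() only needs a schema, not data.
    val schema = StructType(Seq(StructField("features", new VectorUDT, nullable = false)))
    val lda = new LDA().setK(2)

    // A scalar docConcentration is always valid.
    lda.setDocConcentration(1.1)
    lda.transformSchema(schema)

    // An array whose length is neither 1 nor k should be rejected at schema time.
    lda.setDocConcentration(Array(2.0, 3.0, 4.0))
    try {
      lda.transformSchema(schema)
    } catch {
      case e: IllegalArgumentException => println("rejected: " + e.getMessage)
    }
  }
}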
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc07b6..87206c777e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") {
+ val dummyDF = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(10).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(0).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
}
}
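
As in LDASuite, the min/max precondition now fires inside transformSchema(). A hedged sketch of both invalid configurations exercised by the test (the object name MinMaxScalerSketch and the output column name are illustrative):

import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{StructField, StructType}

object MinMaxScalerSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("feature", new VectorUDT, nullable = false)))

    // max must be strictly greater than min; both pairs below violate that.
    for ((min, max) <- Seq((10.0, 0.0), (0.0, 0.0))) {
      val scaler = new MinMaxScaler()
        .setMin(min).setMax(max)
        .setInputCol("feature").setOutputCol("scaled")
      try {
        scaler.transformSchema(schema)
      } catch {
        case e: IllegalArgumentException => println("rejected: " + e.getMessage)
      }
    }
  }
}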
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5df3..6bb4678dc5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
- val slicer = new VectorSlicer
+ val slicer = new VectorSlicer().setInputCol("feature")
ParamsSuite.checkParams(slicer)
assert(slicer.getIndices.length === 0)
assert(slicer.getNames.length === 0)
withClue("VectorSlicer should not have any features selected by default") {
intercept[IllegalArgumentException] {
- slicer.validateParams()
+ slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
}
}
}
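
The same schema-time check applies here: with neither indices nor names selected, VectorSlicer is rejected before any data is touched. A brief sketch under the same assumptions (the object name VectorSlicerSketch and the output column are illustrative):

import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{StructField, StructType}

object VectorSlicerSketch {
  def main(args: Array[String]): Unit = {
    // Same schema shape the updated test builds inline.
    val schema = StructType(Seq(StructField("feature", new VectorUDT, true)))
    val slicer = new VectorSlicer().setInputCol("feature").setOutputCol("sliced")

    // Empty selection (no indices, no names) fails schema validation.
    try {
      slicer.transformSchema(schema)
    } catch {
      case e: IllegalArgumentException => println("rejected: " + e.getMessage)
    }
  }
}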