diff options
Diffstat (limited to 'mllib/src/test')
3 files changed, 17 insertions, 13 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a3a8f65eac..dd3f4c6e53 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead new LDA().setTopicConcentration(-1.1) } - // validateParams() - lda.validateParams() + val dummyDF = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") + // validate parameters + lda.transformSchema(dummyDF.schema) lda.setDocConcentration(1.1) - lda.validateParams() + lda.transformSchema(dummyDF.schema) lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray) - lda.validateParams() + lda.transformSchema(dummyDF.schema) lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray) withClue("LDA docConcentration validity check failed for bad array length") { intercept[IllegalArgumentException] { - lda.validateParams() + lda.transformSchema(dummyDF.schema) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 035bfc07b6..87206c777e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { + val dummyDF = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(10).setMax(0) - scaler.validateParams() + val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") + scaler.transformSchema(dummyDF.schema) } intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(0).setMax(0) - scaler.validateParams() + val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature") + scaler.transformSchema(dummyDF.schema) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 94191e5df3..6bb4678dc5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { - val slicer = new VectorSlicer + val slicer = new VectorSlicer().setInputCol("feature") ParamsSuite.checkParams(slicer) assert(slicer.getIndices.length === 0) assert(slicer.getNames.length === 0) withClue("VectorSlicer should not have any features selected by default") { intercept[IllegalArgumentException] { - slicer.validateParams() + slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true)))) } } } |