diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 86 |
1 files changed, 84 insertions, 2 deletions
package org.apache.spark.ml.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Shared helpers for ML test suites: a copy check for fitted models and generators of small
 * DataFrames whose label column is cast to various SQL types.
 *
 * The object extends [[SparkFunSuite]] so that `assert` and `intercept` are available here.
 */
object MLTestingUtils extends SparkFunSuite {

  /** Label types exercised by the numeric-label generators below. */
  private val numericLabelTypes: Seq[NumericType] =
    Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))

  /**
   * Copies `model` with an empty [[ParamMap]] and asserts that the copy's parent has the same
   * uid as, and is equal to, the parent of the original model.
   *
   * @param model the fitted model to copy
   */
  def checkCopy(model: Model[_]): Unit = {
    val originalParent = model.parent
    val copiedParent = model.copy(ParamMap.empty).asInstanceOf[Model[_]].parent
    assert(copiedParent.uid == originalParent.uid)
    assert(copiedParent == originalParent)
  }

  /**
   * Fits `estimator` once per numeric label type and hands each resulting model to `check`
   * together with the model fitted on `DoubleType` labels; then asserts that fitting on a
   * string-typed label column throws an `IllegalArgumentException` with the expected message.
   *
   * @param estimator        estimator under test
   * @param isClassification selects the classification data generator when true, the
   *                         regression one otherwise
   * @param sqlContext       context used to build the test DataFrames
   * @param check            called as `check(expected, actual)`, where `expected` is the model
   *                         fitted on `DoubleType` labels and `actual` a model fitted on one
   *                         of the other label types
   * @tparam M model type produced by the estimator
   * @tparam T estimator type
   */
  def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
      estimator: T,
      isClassification: Boolean,
      sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
    val framesByType =
      if (isClassification) genClassifDFWithNumericLabelCol(sqlContext)
      else genRegressionDFWithNumericLabelCol(sqlContext)

    val reference = estimator.fit(framesByType(DoubleType))
    // Fit every non-Double variant first; the comparisons only start once all fits are done.
    val candidates =
      for (labelType <- framesByType.keys if labelType != DoubleType)
        yield estimator.fit(framesByType(labelType))
    candidates.foreach(check(reference, _))

    val stringLabelled = generateDFWithStringLabelCol(sqlContext)
    val error = intercept[IllegalArgumentException] {
      estimator.fit(stringLabelled)
    }
    assert(error.getMessage.contains(
      "Column label must be of type NumericType but was actually of type StringType"))
  }

  /**
   * Builds a five-row, two-class DataFrame (labels 0/1, three-dimensional dense features) and
   * returns one variant per numeric label type, keyed by that type.
   *
   * Each variant is passed through `TreeTests.setMetadata` with 2 as the second argument
   * (presumably the number of classes -- confirm against TreeTests).
   *
   * @param sqlContext      context used to create the DataFrame
   * @param labelColName    name of the label column
   * @param featuresColName name of the features column
   * @return map from label type to the DataFrame whose label column is cast to that type
   */
  def genClassifDFWithNumericLabelCol(
      sqlContext: SQLContext,
      labelColName: String = "label",
      featuresColName: String = "features"): Map[NumericType, DataFrame] = {
    val rows = Seq(
      (0, Vectors.dense(0, 2, 3)),
      (1, Vectors.dense(0, 3, 1)),
      (0, Vectors.dense(0, 2, 2)),
      (1, Vectors.dense(0, 3, 9)),
      (0, Vectors.dense(0, 2, 6)))
    val base = sqlContext.createDataFrame(rows).toDF(labelColName, featuresColName)

    numericLabelTypes.map { labelType =>
      val casted = base.select(col(labelColName).cast(labelType), col(featuresColName))
      labelType -> TreeTests.setMetadata(casted, 2, labelColName)
    }.toMap
  }

  /**
   * Builds a five-row regression DataFrame (labels 0..4, one-dimensional dense features) and
   * returns one variant per numeric label type, keyed by that type.
   *
   * Each variant is passed through `TreeTests.setMetadata` with 0 as the second argument and
   * then gains a constant `0.0` column named `censorColName`.
   *
   * @param sqlContext      context used to create the DataFrame
   * @param labelColName    name of the label column
   * @param featuresColName name of the features column
   * @param censorColName   name of the constant censor column that is appended
   * @return map from label type to the DataFrame whose label column is cast to that type
   */
  def genRegressionDFWithNumericLabelCol(
      sqlContext: SQLContext,
      labelColName: String = "label",
      featuresColName: String = "features",
      censorColName: String = "censor"): Map[NumericType, DataFrame] = {
    val rows = Seq(
      (0, Vectors.dense(0)),
      (1, Vectors.dense(1)),
      (2, Vectors.dense(2)),
      (3, Vectors.dense(3)),
      (4, Vectors.dense(4)))
    val base = sqlContext.createDataFrame(rows).toDF(labelColName, featuresColName)

    numericLabelTypes.map { labelType =>
      val casted = base.select(col(labelColName).cast(labelType), col(featuresColName))
      val withMetadata = TreeTests.setMetadata(casted, 0, labelColName)
      labelType -> withMetadata.withColumn(censorColName, lit(0.0))
    }.toMap
  }

  /**
   * Builds a five-row DataFrame whose label column holds the strings "0"/"1", alongside
   * three-dimensional dense features and a Double censor column.
   *
   * @param sqlContext      context used to create the DataFrame
   * @param labelColName    name of the (string) label column
   * @param featuresColName name of the features column
   * @param censorColName   name of the censor column
   * @return the DataFrame with columns (label, features, censor)
   */
  def generateDFWithStringLabelCol(
      sqlContext: SQLContext,
      labelColName: String = "label",
      featuresColName: String = "features",
      censorColName: String = "censor"): DataFrame = {
    val rows = Seq(
      ("0", Vectors.dense(0, 2, 3), 0.0),
      ("1", Vectors.dense(0, 3, 1), 1.0),
      ("0", Vectors.dense(0, 2, 2), 0.0),
      ("1", Vectors.dense(0, 3, 9), 1.0),
      ("0", Vectors.dense(0, 2, 6), 0.0))
    sqlContext.createDataFrame(rows).toDF(labelColName, featuresColName, censorColName)
  }
}