aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala 86
1 files changed, 84 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d290cc9b06..8108460518 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -17,14 +17,96 @@
package org.apache.spark.ml.util
-import org.apache.spark.ml.Model
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
-object MLTestingUtils {
+object MLTestingUtils extends SparkFunSuite {
// Verifies that copying a model keeps its link to the parent estimator:
// the copy, made with an empty ParamMap, must report a parent with the same
// uid as the original's parent, and the two parents must compare equal.
def checkCopy(model: Model[_]): Unit = {
// copy() is called with no parameter overrides; the result is cast back to Model[_]
// so that `parent` can be read from it.
val copied = model.copy(ParamMap.empty)
.asInstanceOf[Model[_]]
assert(copied.parent.uid == model.parent.uid)
assert(copied.parent == model.parent)
}
+
+  /**
+   * Fits `estimator` on the same data with the label column cast to each generated numeric
+   * type, then passes every non-double model to `check` together with the model fitted on
+   * the DoubleType labels. Afterwards verifies that fitting on a string label column throws
+   * an IllegalArgumentException carrying the expected message.
+   *
+   * @param estimator estimator under test
+   * @param isClassification when true the classification fixture is used, otherwise the
+   *                         regression fixture
+   * @param sqlContext context used to build the test DataFrames
+   * @param check callback invoked as `check(expected, actual)`, where `expected` is the model
+   *              fitted on DoubleType labels
+   */
+  def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
+      estimator: T,
+      isClassification: Boolean,
+      sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+    val dfsByLabelType =
+      if (isClassification) genClassifDFWithNumericLabelCol(sqlContext)
+      else genRegressionDFWithNumericLabelCol(sqlContext)
+
+    // Every model is fitted before the first comparison runs.
+    val reference = estimator.fit(dfsByLabelType(DoubleType))
+    val otherLabelTypes = dfsByLabelType.keys.filterNot(_ == DoubleType)
+    val fitted = otherLabelTypes.map(labelType => estimator.fit(dfsByLabelType(labelType)))
+    fitted.foreach(check(reference, _))
+
+    // A string-typed label column has to be rejected with a descriptive error.
+    val stringLabelDF = generateDFWithStringLabelCol(sqlContext)
+    val error = intercept[IllegalArgumentException] {
+      estimator.fit(stringLabelDF)
+    }
+    val expectedMessage =
+      "Column label must be of type NumericType but was actually of type StringType"
+    assert(error.getMessage.contains(expectedMessage))
+  }
+
+  /**
+   * Builds a five-row, two-label (0/1) classification DataFrame and returns one variant per
+   * numeric type, with the label column cast to that type.
+   *
+   * @param sqlContext context used to create the base DataFrame
+   * @param labelColName name given to the label column
+   * @param featuresColName name given to the features column
+   * @return map from each numeric type to the DataFrame whose label column has that type
+   */
+  def genClassifDFWithNumericLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features"): Map[NumericType, DataFrame] = {
+    val rows = Seq(
+      (0, Vectors.dense(0, 2, 3)),
+      (1, Vectors.dense(0, 3, 1)),
+      (0, Vectors.dense(0, 2, 2)),
+      (1, Vectors.dense(0, 3, 9)),
+      (0, Vectors.dense(0, 2, 6)))
+    val baseDF = sqlContext.createDataFrame(rows).toDF(labelColName, featuresColName)
+
+    val labelTypes: Seq[NumericType] =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+    labelTypes.map { labelType =>
+      val casted = baseDF.select(col(labelColName).cast(labelType), col(featuresColName))
+      // The 2 is handed to TreeTests.setMetadata as-is; presumably the number of
+      // classes -- confirm against TreeTests.
+      labelType -> TreeTests.setMetadata(casted, 2, labelColName)
+    }.toMap
+  }
+
+  /**
+   * Builds a five-row regression DataFrame (labels 0 to 4, one feature each) and returns one
+   * variant per numeric type, with the label column cast to that type. Every variant also
+   * gets an extra column named `censorColName` holding the literal 0.0.
+   *
+   * @param sqlContext context used to create the base DataFrame
+   * @param labelColName name given to the label column
+   * @param featuresColName name given to the features column
+   * @param censorColName name of the constant column appended to each variant
+   * @return map from each numeric type to the DataFrame whose label column has that type
+   */
+  def genRegressionDFWithNumericLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features",
+      censorColName: String = "censor"): Map[NumericType, DataFrame] = {
+    val rows = Seq(
+      (0, Vectors.dense(0)),
+      (1, Vectors.dense(1)),
+      (2, Vectors.dense(2)),
+      (3, Vectors.dense(3)),
+      (4, Vectors.dense(4)))
+    val baseDF = sqlContext.createDataFrame(rows).toDF(labelColName, featuresColName)
+
+    val labelTypes: Seq[NumericType] =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+    labelTypes.map { labelType =>
+      val casted = baseDF.select(col(labelColName).cast(labelType), col(featuresColName))
+      // The 0 is handed to TreeTests.setMetadata as-is; presumably it marks the label
+      // as continuous -- confirm against TreeTests.
+      val withMetadata = TreeTests.setMetadata(casted, 0, labelColName)
+      labelType -> withMetadata.withColumn(censorColName, lit(0.0))
+    }.toMap
+  }
+
+  /**
+   * Builds a five-row DataFrame whose label column holds the strings "0" and "1",
+   * alongside a features column and a double-valued censor column.
+   *
+   * @param sqlContext context used to create the DataFrame
+   * @param labelColName name given to the string label column
+   * @param featuresColName name given to the features column
+   * @param censorColName name given to the censor column
+   * @return the DataFrame with the three columns in the order label, features, censor
+   */
+  def generateDFWithStringLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features",
+      censorColName: String = "censor"): DataFrame = {
+    val rows = Seq(
+      ("0", Vectors.dense(0, 2, 3), 0.0),
+      ("1", Vectors.dense(0, 3, 1), 1.0),
+      ("0", Vectors.dense(0, 2, 2), 0.0),
+      ("1", Vectors.dense(0, 3, 9), 1.0),
+      ("0", Vectors.dense(0, 2, 6), 0.0))
+    val unnamed = sqlContext.createDataFrame(rows)
+    unnamed.toDF(labelColName, featuresColName, censorColName)
+  }
}