diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index d0e3fe7ad1..89afb94b0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -17,6 +17,86 @@ package org.apache.spark.ml.classification +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} + +class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("extractLabeledPoints") { + def getTestData(labels: Seq[Double]): DataFrame = { + val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } + sqlContext.createDataFrame(data) + } + + val c = new MockClassifier + // Valid dataset + val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) + c.extractLabeledPoints(df0, 6).count() + // Invalid datasets + val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0)) + withClue("Classifier should fail if label is negative") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df1, 6).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0)) + withClue("Classifier should fail if label is not an integer") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df2, 6).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + // extractLabeledPoints with numClasses specified + withClue("Classifier should fail if label is >= numClasses") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df0, numClasses = 5).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.extractLabeledPoints(df0, numClasses = 0).count() + } + assert(e.getMessage.contains("but requires numClasses > 0")) + } + } + + test("getNumClasses") { + def getTestData(labels: Seq[Double]): DataFrame = { + val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } + sqlContext.createDataFrame(data) + } + + val c = new MockClassifier + // Valid dataset + val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) + assert(c.getNumClasses(df0) === 6) + // Invalid datasets + val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1)) + withClue("getNumClasses should fail if label is max label not an integer") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.getNumClasses(df1) + } + assert(e.getMessage.contains("requires integers in range")) + } + val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble)) + withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.getNumClasses(df2) + } + assert(e.getMessage.contains("requires integers in range")) + } + } +} + object ClassifierSuite { /** @@ -29,4 +109,32 @@ object ClassifierSuite { "rawPredictionCol" -> "myRawPrediction" ) + class MockClassifier(override val uid: String) + extends Classifier[Vector, MockClassifier, MockClassificationModel] { + + def this() = this(Identifiable.randomUID("mockclassifier")) + + override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError() + + override def train(dataset: Dataset[_]): MockClassificationModel = + throw new NotImplementedError() + + // Make methods public + override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = + super.extractLabeledPoints(dataset, numClasses) + def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset) + } + + class MockClassificationModel(override val uid: String) + extends ClassificationModel[Vector, MockClassificationModel] { + + def this() = this(Identifiable.randomUID("mockclassificationmodel")) + + protected def predictRaw(features: Vector): Vector = throw new NotImplementedError() + + override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError() + + override def numClasses: Int = throw new NotImplementedError() + } + } |