aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala108
1 files changed, 108 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index d0e3fe7ad1..89afb94b0f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -17,6 +17,86 @@
package org.apache.spark.ml.classification
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("extractLabeledPoints") {
+ def getTestData(labels: Seq[Double]): DataFrame = {
+ val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+ sqlContext.createDataFrame(data)
+ }
+
+ val c = new MockClassifier
+ // Valid dataset
+ val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+ c.extractLabeledPoints(df0, 6).count()
+ // Invalid datasets
+ val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
+ withClue("Classifier should fail if label is negative") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df1, 6).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
+ withClue("Classifier should fail if label is not an integer") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df2, 6).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ // extractLabeledPoints with numClasses specified
+ withClue("Classifier should fail if label is >= numClasses") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df0, numClasses = 5).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.extractLabeledPoints(df0, numClasses = 0).count()
+ }
+ assert(e.getMessage.contains("but requires numClasses > 0"))
+ }
+ }
+
+ test("getNumClasses") {
+ def getTestData(labels: Seq[Double]): DataFrame = {
+ val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+ sqlContext.createDataFrame(data)
+ }
+
+ val c = new MockClassifier
+ // Valid dataset
+ val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+ assert(c.getNumClasses(df0) === 6)
+ // Invalid datasets
+ val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1))
+ withClue("getNumClasses should fail if label is max label not an integer") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.getNumClasses(df1)
+ }
+ assert(e.getMessage.contains("requires integers in range"))
+ }
+ val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble))
+ withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.getNumClasses(df2)
+ }
+ assert(e.getMessage.contains("requires integers in range"))
+ }
+ }
+}
+
object ClassifierSuite {
/**
@@ -29,4 +109,32 @@ object ClassifierSuite {
"rawPredictionCol" -> "myRawPrediction"
)
+ class MockClassifier(override val uid: String)
+ extends Classifier[Vector, MockClassifier, MockClassificationModel] {
+
+ def this() = this(Identifiable.randomUID("mockclassifier"))
+
+ override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
+
+ override def train(dataset: Dataset[_]): MockClassificationModel =
+ throw new NotImplementedError()
+
+ // Make methods public
+ override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
+ super.extractLabeledPoints(dataset, numClasses)
+ def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+ }
+
+ class MockClassificationModel(override val uid: String)
+ extends ClassificationModel[Vector, MockClassificationModel] {
+
+ def this() = this(Identifiable.randomUID("mockclassificationmodel"))
+
+ protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
+
+ override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
+
+ override def numClasses: Int = throw new NotImplementedError()
+ }
+
}