aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala90
1 files changed, 84 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2090..f3e8fd11b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var rdd: RDD[LabeledPoint] = _
override def beforeAll(): Unit = {
@@ -74,7 +74,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
// copied model must have the same parent.
MLTestingUtils.checkCopy(ovaModel)
- assert(ovaModel.models.size === numClasses)
+ assert(ovaModel.models.length === numClasses)
val transformedDataset = ovaModel.transform(dataset)
// check for label metadata in prediction col
@@ -160,6 +160,84 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
+
+ test("read/write: OneVsRest") {
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+ val ova = new OneVsRest()
+ .setClassifier(lr)
+ .setLabelCol("myLabel")
+ .setFeaturesCol("myFeature")
+ .setPredictionCol("myPrediction")
+
+ val ova2 = testDefaultReadWrite(ova, testParams = false)
+ assert(ova.uid === ova2.uid)
+ assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+ assert(ova.getLabelCol === ova2.getLabelCol)
+ assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+ ova2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ assert(lr.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: OneVsRestModel") {
+ def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+ assert(model.uid === model2.uid)
+ assert(model.getFeaturesCol === model2.getFeaturesCol)
+ assert(model.getLabelCol === model2.getLabelCol)
+ assert(model.getPredictionCol === model2.getPredictionCol)
+
+ val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+ model2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(classifier.uid === lr2.uid)
+ assert(classifier.getMaxIter === lr2.getMaxIter)
+ assert(classifier.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ assert(model.labelMetadata === model2.labelMetadata)
+ model.models.zip(model2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.uid === lrModel2.uid)
+ assert(lrModel1.coefficients === lrModel2.coefficients)
+ assert(lrModel1.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ }
+
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+ val ova = new OneVsRest().setClassifier(lr)
+ val ovaModel = ova.fit(dataset)
+ val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+ checkModelData(ovaModel, newOvaModel)
+ }
+
+ test("should support all NumericType labels and not support other types") {
+ val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
+ MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
+ ovr, isClassification = true, sqlContext) { (expected, actual) =>
+ val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+ val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+ assert(expectedModels.length === actualModels.length)
+ expectedModels.zip(actualModels).foreach { case (e, a) =>
+ assert(e.intercept === a.intercept)
+ assert(e.coefficients.toArray === a.coefficients.toArray)
+ }
+ }
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
@@ -168,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
setMaxIter(1)
- override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val labelSchema = dataset.schema($(labelCol))
// check for label attribute propagation.
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))