Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 3
11 files changed, 38 insertions, 15 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index f680d8d3c4..815f6fd997 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -59,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
- val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
ParamsSuite.checkParams(model)
}
@@ -310,6 +310,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
dt: DecisionTreeClassifier,
categoricalFeatures: Map[Int, Int],
numClasses: Int): Unit = {
+ val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
@@ -318,5 +319,6 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
+ assert(newTree.numFeatures === numFeatures)
}
}
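The two hunks above show the new DecisionTreeClassificationModel constructor, which now takes numFeatures ahead of numClasses, and the compareAPIs helper deriving the expected value from the training data. A minimal sketch of the new shape with placeholder values; it assumes the snippet is compiled inside the org.apache.spark.ml test tree, since the model constructor and LeafNode are private[ml]:

    import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    import org.apache.spark.ml.tree.LeafNode

    // Argument order after this patch: uid, root node, numFeatures, numClasses.
    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
    assert(model.numFeatures == 1)
    assert(model.numClasses == 2)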
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e3909bccaa..039141aeb6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -59,8 +59,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
- Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
- Array(1.0))
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
+ Array(1.0), 1)
ParamsSuite.checkParams(model)
}
@@ -145,7 +145,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
}
-private object GBTClassifierSuite {
+private object GBTClassifierSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -156,6 +156,7 @@ private object GBTClassifierSuite {
validationData: Option[RDD[LabeledPoint]],
gbt: GBTClassifier,
categoricalFeatures: Map[Int, Int]): Unit = {
+ val numFeatures = data.first().features.size
val oldBoostingStrategy =
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(oldBoostingStrategy)
@@ -164,7 +165,9 @@ private object GBTClassifierSuite {
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
- oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
+ assert(newModel.numFeatures === numFeatures)
+ assert(oldModelAsNew.numFeatures === numFeatures)
}
}
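As in the decision tree suite, compareAPIs now reads the expected feature count off the first training example and forwards it to GBTClassificationModel.fromOld. A hedged helper capturing just that derivation (the name is hypothetical; data is an RDD[LabeledPoint] as in the suite):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Mirrors the added `val numFeatures = data.first().features.size` lines:
    // the feature count is taken from the first labeled point in the training RDD.
    def expectedNumFeatures(data: RDD[LabeledPoint]): Int = data.first().features.size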
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index f5219f9f57..ec01998601 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -194,6 +194,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val model = lr.fit(dataset)
assert(model.numClasses === 2)
+ val numFeatures = dataset.select("features").first().getAs[Vector](0).size
+ assert(model.numFeatures === numFeatures)
val threshold = model.getThreshold
val results = model.transform(dataset)
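For estimators fit on a DataFrame rather than an RDD, the added assertions compare model.numFeatures against the width of the "features" column. A minimal sketch of that check under the pre-2.0 mllib Vector type imported by these suites; df is assumed to carry the usual "label" and "features" columns:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.DataFrame

    def checkNumFeatures(df: DataFrame): Unit = {
      val model = new LogisticRegression().setMaxIter(10).fit(df)
      // Width of the feature vectors, read from the first row of the "features" column.
      val numFeatures = df.select("features").first().getAs[Vector](0).size
      assert(model.numFeatures == numFeatures)
    }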
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index ddc948f65d..2d1df9b2b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
@@ -73,6 +73,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
.setSeed(11L)
.setMaxIter(numIterations)
val model = trainer.fit(dataFrame)
+ val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
+ assert(model.numFeatures === numFeatures)
val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
.map { case Row(p: Double, l: Double) => (p, l) }
// train multinomial logistic regression
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 8f50cb924e..fb5f00e064 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
final class TestProbabilisticClassificationModel(
override val uid: String,
+ override val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
@@ -45,13 +46,14 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
test("test thresholding") {
val thresholds = Array(0.5, 0.2)
- val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+ .setThresholds(thresholds)
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
}
test("test thresholding not required") {
- val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
}
}
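Only the TestProbabilisticClassificationModel constructor changes here, gaining numFeatures as its second argument (uid, numFeatures, numClasses); the thresholding behaviour under test is untouched. As I understand that rule, the predicted class is the one maximizing probability(i) / threshold(i), which is easy to verify for the two cases above in plain Scala:

    // Hedged sketch of the thresholded-prediction rule (argmax of p / t); no Spark needed.
    def thresholdedPrediction(probs: Array[Double], thresholds: Array[Double]): Int =
      probs.zip(thresholds).map { case (p, t) => p / t }.zipWithIndex.maxBy(_._1)._2

    assert(thresholdedPrediction(Array(1.0, 1.0), Array(0.5, 0.2)) == 1) // matches the 1.0 above
    assert(thresholdedPrediction(Array(1.0, 0.2), Array(0.5, 0.2)) == 0) // matches the 0.0 above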
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index b4403ec300..deb8ec771c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -68,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -209,7 +209,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
*/
}
-private object RandomForestClassifierSuite {
+private object RandomForestClassifierSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -220,6 +220,7 @@ private object RandomForestClassifierSuite {
rf: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
numClasses: Int): Unit = {
+ val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
val oldModel = OldRandomForest.trainClassifier(
@@ -233,6 +234,7 @@ private object RandomForestClassifierSuite {
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
- assert(newModel.numClasses == numClasses)
+ assert(newModel.numClasses === numClasses)
+ assert(newModel.numFeatures === numFeatures)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index b092bcd6a7..868fb8eecb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -89,6 +89,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
data: RDD[LabeledPoint],
dt: DecisionTreeRegressor,
categoricalFeatures: Map[Int, Int]): Unit = {
+ val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures)
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
@@ -97,5 +98,6 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
+ assert(newTree.numFeatures === numFeatures)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index a68197b591..09326600e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -156,7 +156,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
}
-private object GBTRegressorSuite {
+private object GBTRegressorSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -167,6 +167,7 @@ private object GBTRegressorSuite {
validationData: Option[RDD[LabeledPoint]],
gbt: GBTRegressor,
categoricalFeatures: Map[Int, Int]): Unit = {
+ val numFeatures = data.first().features.size
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val oldGBT = new OldGBT(oldBoostingStrategy)
val oldModel = oldGBT.run(data)
@@ -174,7 +175,9 @@ private object GBTRegressorSuite {
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTRegressionModel.fromOld(
- oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures, numFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
+ assert(newModel.numFeatures === numFeatures)
+ assert(oldModelAsNew.numFeatures === numFeatures)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 8428f4f00b..7cb9471e69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,8 +22,8 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -87,6 +87,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.getPredictionCol === "prediction")
assert(model.intercept !== 0.0)
assert(model.hasParent)
+ val numFeatures = dataset.select("features").first().getAs[Vector](0).size
+ assert(model.numFeatures === numFeatures)
}
test("linear regression with intercept without regularization") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7b1b3f1148..7e751e4b55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -137,6 +137,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
data: RDD[LabeledPoint],
rf: RandomForestRegressor,
categoricalFeatures: Map[Int, Int]): Unit = {
+ val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
val oldModel = OldRandomForest.trainRegressor(
@@ -147,5 +148,6 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
val oldModelAsNew = RandomForestRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
+ assert(newModel.numFeatures === numFeatures)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dc852795c7..d5c238e9ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -77,7 +77,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Forest consisting of (full tree) + (internal node with 2 leafs)
val trees = Array(parent, grandParent).map { root =>
- new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
+ new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
+ .asInstanceOf[DecisionTreeModel]
}
val importances: Vector = RandomForest.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
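In the importance test the trees are now built with an explicit numFeatures = 2, matching the second argument of RandomForest.featureImportances(trees, 2). My understanding is that featureImportances normalizes each tree's importances to sum to one and then averages them across trees, which is what tree2norm above normalizes against; a plain-Scala sketch of that averaging, under that assumption:

    // Assumed behaviour, not the Spark implementation itself: normalize each tree's
    // importance vector to sum to 1, then average the normalized vectors per feature.
    def averagedImportances(perTree: Seq[Array[Double]]): Array[Double] = {
      val normalized = perTree.map { imps =>
        val total = imps.sum
        imps.map(_ / total)
      }
      Array.tabulate(normalized.head.length) { j =>
        normalized.map(_(j)).sum / normalized.size
      }
    }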