 mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                               |  4
 mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala           |  6
 mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala      |  8
 mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  | 11
 mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala                  |  7
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala                           | 15
 mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                |  6
 mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala            |  3
 mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala       | 10
 mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                        |  5
 mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                        | 90
 mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                         |  2
 mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala      |  8
 mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala |  6
 project/MimaExcludes.scala                                                            | 30
 python/pyspark/ml/util.py                                                             | 40
 16 files changed, 144 insertions(+), 107 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index f406f8c426..38176b96ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
*
* Check transform validity and derive the output schema from the input schema.
*
+ * We check validity for interactions between parameters during `transformSchema` and
+ * raise an exception if any parameter value is invalid. Parameter value checks which
+ * do not depend on other parameters are handled by `Param.validate()`.
+ *
* Typical implementation should first conduct verification on schema change and parameter
* validity, including complex parameter interaction checks.
*/
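
For context, a minimal sketch of the pattern this doc change describes: a hypothetical stage (the names BinningStage, minBins, and maxBins are invented for illustration) keeps single-value checks in its Param validators and performs the cross-parameter check inside transformSchema:

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.types.StructType

    class BinningStage(override val uid: String) extends Transformer {
      def this() = this(Identifiable.randomUID("binning"))

      // Per-value checks live in the Param validators...
      final val minBins = new IntParam(this, "minBins", "lower bound (>= 1)",
        ParamValidators.gtEq(1))
      final val maxBins = new IntParam(this, "maxBins", "upper bound (>= 1)",
        ParamValidators.gtEq(1))
      setDefault(minBins -> 2, maxBins -> 32)

      // ...while the interaction between params is checked here.
      override def transformSchema(schema: StructType): StructType = {
        require($(minBins) <= $(maxBins),
          s"minBins (${$(minBins)}) must be <= maxBins (${$(maxBins)}).")
        schema  // no output columns added in this sketch
      }

      override def transform(dataset: Dataset[_]): DataFrame = {
        transformSchema(dataset.schema, logging = true)
        dataset.toDF()
      }

      override def copy(extra: ParamMap): BinningStage = defaultCopy(extra)
    }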
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 52f93f5a6b..ca52231333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -203,6 +203,12 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
+ /**
+ * Number of trees in ensemble
+ */
+ @Since("2.0.0")
+ val getNumTrees: Int = trees.length
+
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
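
With getNumTrees now defined on the model itself, reading the ensemble size looks like this (a usage sketch; `training` is assumed to be a DataFrame with "label" and "features" columns):

    import org.apache.spark.ml.classification.GBTClassifier

    val gbt = new GBTClassifier().setMaxIter(10)
    val model = gbt.fit(training)
    // For GBTs the ensemble size equals the number of boosting iterations run.
    assert(model.getNumTrees == model.trees.length)
    println(s"Trees in ensemble: ${model.getNumTrees}")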
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fe29926e0d..41b84f4816 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils
@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
}
- override def validateParams(): Unit = {
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
checkThresholdConsistency()
+ super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}
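
The practical effect: an inconsistent `threshold`/`thresholds` pair now fails during schema validation, i.e. when fit() or transform() runs, rather than through a separate validateParams() call. A sketch:

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()
      .setThreshold(0.25)              // equivalent to thresholds (0.75, 0.25)
      .setThresholds(Array(0.3, 0.7))  // implies threshold 0.7 -- inconsistent

    // lr.fit(training)  // throws IllegalArgumentException from
    //                   // validateAndTransformSchema via checkThresholdConsistency()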
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 907c73e2e4..d151213f9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
}
}
- /**
- * Number of trees in ensemble
- *
- * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
- */
- // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
- @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
- val numTrees: Int = trees.length
-
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
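
Callers that read the removed numTrees field migrate to getNumTrees, which the model now inherits from RandomForestClassifierParams (a usage sketch; `training` is an assumed DataFrame with "label" and "features" columns):

    import org.apache.spark.ml.classification.RandomForestClassifier

    val rf = new RandomForestClassifier().setNumTrees(30)
    val model = rf.fit(training)
    // Before 2.1: model.numTrees (deprecated field, now removed)
    println(s"Ensemble size: ${model.getNumTrees}")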
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 653fa41124..7cd0f159c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
- /**
- * @group setParam
- */
- @Since("1.6.0")
- @deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
- def setLabelCol(value: String): this.type = set(labelCol, value)
-
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
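
Since the label column is consumed only while fitting, it belongs on the estimator; the removed model setter had no effect on transform(). A sketch of the surviving API:

    import org.apache.spark.ml.feature.ChiSqSelector

    val selector = new ChiSqSelector()
      .setNumTopFeatures(50)
      .setFeaturesCol("features")
      .setLabelCol("label")             // used here, during fit()
      .setOutputCol("selectedFeatures")

    val model = selector.fit(training)
    val selected = model.transform(training)  // labelCol plays no role here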
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 96206e0b7a..5bd8ebe098 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -547,21 +547,6 @@ trait Params extends Identifiable with Serializable {
}
/**
- * Validates parameter values stored internally.
- * Raise an exception if any parameter value is invalid.
- *
- * This only needs to check for interactions between parameters.
- * Parameter value checks which do not depend on other parameters are handled by
- * `Param.validate()`. This method does not handle input/output column parameters;
- * those are checked during schema validation.
- * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
- */
- @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
- def validateParams(): Unit = {
- // Do nothing by default. Override to handle Param interactions.
- }
-
- /**
* Explains a param.
* @param param input param, must belong to this instance.
* @return a string that contains the input param name, doc, and optionally its default value and
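
A migration sketch for code that overrode the removed hook (the trait and helper names below are invented for illustration): the interaction check moves into the schema-validation path, so it runs on both fit() and transform():

    import org.apache.spark.ml.param.Params
    import org.apache.spark.sql.types.StructType

    trait MyStageParams extends Params {
      // Before 2.1:
      //   override def validateParams(): Unit = checkInteractions()
      // From 2.1 on, call the check from the stage's transformSchema
      // (or a shared validateAndTransformSchema helper):
      protected def validateAndTransformSchema(schema: StructType): StructType = {
        checkInteractions()
        schema
      }

      protected def checkInteractions(): Unit = ()  // param interaction checks go here
    }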
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index ed2d05525d..6d8159aa3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
+ /**
+ * Number of trees in ensemble
+ */
+ @Since("2.0.0")
+ val getNumTrees: Int = trees.length
+
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index eb4e38cc83..19ddf36a71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] (
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
- @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
- val model: LinearRegressionModel = privateModel
-
@transient private val metrics = new RegressionMetrics(
predictions
.select(col(predictionCol), col(labelCol).cast(DoubleType))
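
Code that reached the model through the summary's deprecated `model` field already holds the model wherever the summary came from, so it can use that reference directly (sketch; `training` assumed as before):

    import org.apache.spark.ml.regression.LinearRegression

    val lrModel = new LinearRegression().fit(training)
    val summary = lrModel.summary       // was: summary.model to get back to lrModel
    println(s"RMSE: ${summary.rootMeanSquaredError}")
    println(s"Coefficients: ${lrModel.coefficients}")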
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index d60f05eed5..90d89c51c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -145,7 +145,7 @@ class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -182,14 +182,6 @@ class RandomForestRegressionModel private[ml] (
_trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
- /**
- * Number of trees in ensemble
- * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
- */
- // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
- @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
- val numTrees: Int = trees.length
-
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index d3cbc36379..0d6e9034e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[M]
- /**
- * Number of trees in ensemble
- */
- val getNumTrees: Int = trees.length
-
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 40510ad804..83ab4b5da8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -319,8 +319,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/** Used for [[RandomForestParams]] */
-private[ml] trait HasFeatureSubsetStrategy extends Params {
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ *
+ * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+ * is that for GBTs the param `maxIter` controls how many trees the ensemble has; the
+ * semantics differ between the two algorithms.
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
+
+ /** @group setParam */
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+ /** @group getParam */
+ final def getNumTrees: Int = $(numTrees)
/**
* The number of features to consider for splits at each tree node.
@@ -366,38 +390,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}
-/**
- * Used for [[RandomForestParams]].
- * This is separated out from [[RandomForestParams]] because of an issue with the
- * `numTrees` method conflicting with this Param in the Estimator.
- */
-private[ml] trait HasNumTrees extends Params {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
-
- setDefault(numTrees -> 20)
-
- /** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
-
- /** @group getParam */
- final def getNumTrees: Int = $(numTrees)
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams
- with HasFeatureSubsetStrategy with HasNumTrees
-
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
@@ -407,21 +399,15 @@ private[spark] object RandomForestParams {
private[ml] trait RandomForestClassifierParams
extends RandomForestParams with TreeClassifierParams
-private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
- with HasFeatureSubsetStrategy with TreeClassifierParams
-
private[ml] trait RandomForestRegressorParams
extends RandomForestParams with TreeRegressorParams
-private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
- with HasFeatureSubsetStrategy with TreeRegressorParams
-
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
/* TODO: Add this doc when we add this param. SPARK-7132
* Threshold for stopping early when runWithValidation is used.
@@ -434,24 +420,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
// final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
// validationTol -> 1e-5
- setDefault(maxIter -> 20, stepSize -> 0.1)
-
/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
/**
- * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
- * estimator.
+ * Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking
+ * the contribution of each estimator.
* (default = 0.1)
- * @group setParam
+ * @group param
*/
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
+ "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ /** @group getParam */
+ final def getStepSize: Double = $(stepSize)
+
+ /** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)
- override def validateParams(): Unit = {
- require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
- getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
- s"but it given invalid value $getStepSize.")
- }
+ setDefault(maxIter -> 20, stepSize -> 0.1)
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
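
With the range validator attached to the Param itself, a bad stepSize now fails at set time instead of waiting for the removed validateParams() hook:

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor()
    gbt.setStepSize(0.05)       // OK: within (0, 1]
    // gbt.setStepSize(10.0)    // throws IllegalArgumentException immediately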
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 5b7e5ec75c..bbb9886391 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite {
* Sets the Spark SQLContext to use for saving/loading.
*/
@Since("1.6.0")
- @deprecated("Use session instead", "2.0.0")
+ @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0")
def context(sqlContext: SQLContext): this.type = {
optionSparkSession = Option(sqlContext.sparkSession)
this
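
The replacement on the Scala side takes the SparkSession directly (sketch; `spark` is the active SparkSession, `model` is any MLWritable model, and the save path is illustrative):

    // Deprecated: model.write.context(spark.sqlContext).save(path)
    model.write.session(spark).save("/tmp/example-model")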
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3492709677..7c36745ab2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
ParamsSuite.checkParams(model)
}
+ test("GBT parameter stepSize should be in interval (0, 1]") {
+ withClue("GBT parameter stepSize should be in interval (0, 1]") {
+ intercept[IllegalArgumentException] {
+ new GBTClassifier().setStepSize(10)
+ }
+ }
+ }
+
test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e360542eae..9c4c59a5e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -194,6 +194,12 @@ class LogisticRegressionSuite
// thresholds and threshold must be consistent: values
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
+ lr2.fit(smallBinaryDataset,
+ lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+ }
+ }
+ withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+ intercept[IllegalArgumentException] {
val lr2model = lr2.fit(smallBinaryDataset,
lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
lr2model.getThreshold
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 12f7ed202b..84014014f2 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -867,6 +867,36 @@ object MimaExcludes {
// [SPARK-12221] Add CPU time to metrics
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
+ ) ++ Seq(
+ // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
)
}
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7d39c30122..bec4b28952 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -78,7 +78,14 @@ class MLWriter(object):
raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
def context(self, sqlContext):
- """Sets the SQL context to use for saving."""
+ """
+ Sets the SQL context to use for saving.
+ .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead.
+ """
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def session(self, sparkSession):
+ """Sets the Spark Session to use for saving."""
raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
@@ -105,10 +112,19 @@ class JavaMLWriter(MLWriter):
return self
def context(self, sqlContext):
- """Sets the SQL context to use for saving."""
+ """
+ Sets the SQL context to use for saving.
+ .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead.
+ """
+ warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.")
self._jwrite.context(sqlContext._ssql_ctx)
return self
+ def session(self, sparkSession):
+ """Sets the Spark Session to use for saving."""
+ self._jwrite.session(sparkSession._jsparkSession)
+ return self
+
@inherit_doc
class MLWritable(object):
@@ -155,7 +171,14 @@ class MLReader(object):
raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
def context(self, sqlContext):
- """Sets the SQL context to use for loading."""
+ """
+ Sets the SQL context to use for loading.
+ .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead.
+ """
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+ def session(self, sparkSession):
+ """Sets the Spark Session to use for loading."""
raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
@@ -180,10 +203,19 @@ class JavaMLReader(MLReader):
return self._clazz._from_java(java_obj)
def context(self, sqlContext):
- """Sets the SQL context to use for loading."""
+ """
+ Sets the SQL context to use for loading.
+ .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead.
+ """
+ warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.")
self._jread.context(sqlContext._ssql_ctx)
return self
+ def session(self, sparkSession):
+ """Sets the Spark Session to use for loading."""
+ self._jread.session(sparkSession._jsparkSession)
+ return self
+
@classmethod
def _java_loader_class(cls, clazz):
"""