aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala13
1 files changed, 8 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 461905c127..4249ff5c1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
@@ -42,10 +42,12 @@ import org.apache.spark.sql.DataFrame
* It supports both continuous and categorical features.
*/
@AlphaComponent
-final class GBTRegressor
+final class GBTRegressor(override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
with GBTParams with TreeRegressorParams with Logging {
+ def this() = this(Identifiable.randomUID("gbtr"))
+
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeRegressorParams:
@@ -149,7 +151,7 @@ object GBTRegressor {
*/
@AlphaComponent
final class GBTRegressionModel(
- override val parent: GBTRegressor,
+ override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTRegressionModel]
@@ -173,7 +175,7 @@ final class GBTRegressionModel(
}
override def copy(extra: ParamMap): GBTRegressionModel = {
- copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra)
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -199,6 +201,7 @@ private[ml] object GBTRegressionModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTRegressionModel(parent, newTrees, oldModel.treeWeights)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
+ new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights)
}
}