Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Model.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 27
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 115
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala | 21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 52
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala | 40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 39
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 3
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
45 files changed, 413 insertions(+), 198 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 9974efe7b1..7fd515369b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -32,7 +32,15 @@ abstract class Model[M <: Model[M]] extends Transformer {
* The parent estimator that produced this model.
* Note: For ensembles' component Models, this value can be null.
*/
- val parent: Estimator[M]
+ var parent: Estimator[M] = _
+
+ /**
+ * Sets the parent of this model (Java API).
+ */
+ def setParent(parent: Estimator[M]): M = {
+ this.parent = parent
+ this.asInstanceOf[M]
+ }
override def copy(extra: ParamMap): M = {
// The default implementation of Params.copy doesn't work for models.
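With parent now a mutable field set via setParent, an estimator attaches itself to the model after construction instead of being passed through the model's constructor (see the Predictor change below, which calls train(dataset).setParent(this)). A minimal self-contained sketch of that wiring, using hypothetical Toy* classes rather than Spark's real ones:

// Illustrative only: mirrors the new uid + setParent pattern, not Spark's actual classes.
abstract class Model[M <: Model[M]] {
  var parent: Estimator[M] = _                 // was a constructor val before this patch
  def setParent(parent: Estimator[M]): M = {
    this.parent = parent
    this.asInstanceOf[M]
  }
}

abstract class Estimator[M <: Model[M]] {
  def fit(): M
}

class ToyModel(val uid: String) extends Model[ToyModel]

class ToyEstimator(val uid: String) extends Estimator[ToyModel] {
  // Build the model with the estimator's uid, then attach the parent.
  override def fit(): ToyModel = new ToyModel(uid).setParent(this)
}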
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 33d430f567..fac54188f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -80,7 +81,9 @@ abstract class PipelineStage extends Params with Logging {
* an identity transformer.
*/
@AlphaComponent
-class Pipeline extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+
+ def this() = this(Identifiable.randomUID("pipeline"))
/**
* param for pipeline stages
@@ -148,7 +151,7 @@ class Pipeline extends Estimator[PipelineModel] {
}
}
- new PipelineModel(this, transformers.toArray)
+ new PipelineModel(uid, transformers.toArray).setParent(this)
}
override def copy(extra: ParamMap): Pipeline = {
@@ -171,7 +174,7 @@ class Pipeline extends Estimator[PipelineModel] {
*/
@AlphaComponent
class PipelineModel private[ml] (
- override val parent: Pipeline,
+ override val uid: String,
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
@@ -190,6 +193,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
- new PipelineModel(parent, stages)
+ new PipelineModel(uid, stages)
}
}
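As with the other components in this patch, Pipeline now takes its uid as a constructor argument and falls back to a random one. A quick usage sketch, assuming the patched API from this commit:

// Assumes the patched org.apache.spark.ml API.
import org.apache.spark.ml.Pipeline

val p1 = new Pipeline()               // auto-generated uid, e.g. "pipeline_<random hex>"
val p2 = new Pipeline("pipeline_0")   // explicit uid, handy for deterministic tests
assert(p1.uid != p2.uid)
// fit() creates the PipelineModel with the same uid and links it back via setParent(this),
// and copy() preserves the uid as well.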
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index f6a5f27425..ec0f76aa66 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -88,7 +88,7 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
- copyValues(train(dataset))
+ copyValues(train(dataset).setParent(this))
}
override def copy(extra: ParamMap): Learner = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index dcebea1d4b..7c961332bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
@@ -39,10 +39,12 @@ import org.apache.spark.sql.DataFrame
* features.
*/
@AlphaComponent
-final class DecisionTreeClassifier
+final class DecisionTreeClassifier(override val uid: String)
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
+ def this() = this(Identifiable.randomUID("dtc"))
+
// Override parameter setters from parent trait for Java API compatibility.
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
@@ -101,7 +103,7 @@ object DecisionTreeClassifier {
*/
@AlphaComponent
final class DecisionTreeClassificationModel private[ml] (
- override val parent: DecisionTreeClassifier,
+ override val uid: String,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -114,7 +116,7 @@ final class DecisionTreeClassificationModel private[ml] (
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
- copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
}
override def toString: String = {
@@ -138,6 +140,7 @@ private[ml] object DecisionTreeClassificationModel {
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
- new DecisionTreeClassificationModel(parent, rootNode)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
+ new DecisionTreeClassificationModel(uid, rootNode)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index ae51b05a0c..d504d84beb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
@@ -44,10 +44,12 @@ import org.apache.spark.sql.DataFrame
* Note: Multiclass labels are not currently supported.
*/
@AlphaComponent
-final class GBTClassifier
+final class GBTClassifier(override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTParams with TreeClassifierParams with Logging {
+ def this() = this(Identifiable.randomUID("gbtc"))
+
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeClassifierParams:
@@ -160,7 +162,7 @@ object GBTClassifier {
*/
@AlphaComponent
final class GBTClassificationModel(
- override val parent: GBTClassifier,
+ override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTClassificationModel]
@@ -184,7 +186,7 @@ final class GBTClassificationModel(
}
override def copy(extra: ParamMap): GBTClassificationModel = {
- copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -210,6 +212,7 @@ private[ml] object GBTClassificationModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 2b10362687..8694c96e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -26,6 +26,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
@@ -50,10 +51,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* Currently, this class only supports binary classification.
*/
@AlphaComponent
-class LogisticRegression
+class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams with Logging {
+ def this() = this(Identifiable.randomUID("logreg"))
+
/**
* Set the regularization parameter.
* Default is 0.0.
@@ -213,7 +216,7 @@ class LogisticRegression
(weightsWithIntercept, 0.0)
}
- new LogisticRegressionModel(this, weights.compressed, intercept)
+ new LogisticRegressionModel(uid, weights.compressed, intercept)
}
}
@@ -224,7 +227,7 @@ class LogisticRegression
*/
@AlphaComponent
class LogisticRegressionModel private[ml] (
- override val parent: LogisticRegression,
+ override val uid: String,
val weights: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
@@ -276,7 +279,7 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
+ copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index afb8d75d57..1543f051cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -40,19 +40,17 @@ private[ml] trait OneVsRestParams extends PredictorParams {
type ClassifierType = Classifier[F, E, M] forSome {
type F
type M <: ClassificationModel[F, M]
- type E <: Classifier[F, E, M]
+ type E <: Classifier[F, E, M]
}
/**
* param for the base binary classifier that we reduce multiclass classification into.
* @group param
*/
- val classifier: Param[ClassifierType] =
- new Param(this, "classifier", "base binary classifier ")
+ val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
/** @group getParam */
def getClassifier: ClassifierType = $(classifier)
-
}
/**
@@ -70,10 +68,10 @@ private[ml] trait OneVsRestParams extends PredictorParams {
* (taking label 0).
*/
@AlphaComponent
-class OneVsRestModel private[ml] (
- override val parent: OneVsRest,
- labelMetadata: Metadata,
- val models: Array[_ <: ClassificationModel[_,_]])
+final class OneVsRestModel private[ml] (
+ override val uid: String,
+ labelMetadata: Metadata,
+ val models: Array[_ <: ClassificationModel[_,_]])
extends Model[OneVsRestModel] with OneVsRestParams {
override def transformSchema(schema: StructType): StructType = {
@@ -145,11 +143,13 @@ class OneVsRestModel private[ml] (
* is picked to label the example.
*/
@Experimental
-final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+final class OneVsRest(override val uid: String)
+ extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+ def this() = this(Identifiable.randomUID("oneVsRest"))
/** @group setParam */
- def setClassifier(value: Classifier[_,_,_]): this.type = {
- // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
+ def setClassifier(value: Classifier[_, _, _]): this.type = {
set(classifier, value.asInstanceOf[ClassifierType])
}
@@ -204,6 +204,7 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
case attr: Attribute => attr
}
- copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+ val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
+ copyValues(model)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 9954893f14..a1de791985 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
@@ -41,10 +41,12 @@ import org.apache.spark.sql.DataFrame
* features.
*/
@AlphaComponent
-final class RandomForestClassifier
+final class RandomForestClassifier(override val uid: String)
extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
+ def this() = this(Identifiable.randomUID("rfc"))
+
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeClassifierParams:
@@ -118,7 +120,7 @@ object RandomForestClassifier {
*/
@AlphaComponent
final class RandomForestClassificationModel private[ml] (
- override val parent: RandomForestClassifier,
+ override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel])
extends PredictionModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -146,7 +148,7 @@ final class RandomForestClassificationModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(parent, _trees), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees), extra)
}
override def toString: String = {
@@ -172,6 +174,7 @@ private[ml] object RandomForestClassificationModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestClassificationModel(parent, newTrees)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
+ new RandomForestClassificationModel(uid, newTrees)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index e5a73c6087..c1af09c969 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,10 @@ import org.apache.spark.sql.types.DoubleType
* Evaluator for binary classification, which expects two input columns: score and label.
*/
@AlphaComponent
-class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol {
+class BinaryClassificationEvaluator(override val uid: String)
+ extends Evaluator with HasRawPredictionCol with HasLabelCol {
+
+ def this() = this(Identifiable.randomUID("binEval"))
/**
* param for metric name in evaluation
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 6eb1db6971..62f4a63434 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -32,7 +32,10 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
* Binarize a column of continuous features given a threshold.
*/
@AlphaComponent
-final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
+final class Binarizer(override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("binarizer"))
/**
* Param for threshold used to binarize continuous features.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index d8f1961cb3..ac8dfb5632 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -21,11 +21,11 @@ import java.{util => ju}
import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -35,10 +35,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
*/
@AlphaComponent
-final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
+final class Bucketizer(override val uid: String)
extends Model[Bucketizer] with HasInputCol with HasOutputCol {
- def this() = this(null)
+ def this() = this(Identifiable.randomUID("bucketizer"))
/**
* Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index f8b56293e3..8b32eee0e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
@@ -31,7 +32,10 @@ import org.apache.spark.sql.types.DataType
* multiplier.
*/
@AlphaComponent
-class ElementwiseProduct extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
+class ElementwiseProduct(override val uid: String)
+ extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
+
+ def this() = this(Identifiable.randomUID("elemProd"))
/**
* the vector to multiply with input vectors
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index c305a819a8..30033ced68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
@@ -29,7 +30,9 @@ import org.apache.spark.sql.types.DataType
* Maps a sequence of terms to their term frequencies using the hashing trick.
*/
@AlphaComponent
-class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+
+ def this() = this(Identifiable.randomUID("hashingTF"))
/**
* Number of features. Should be > 0.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index d901a20aed..788c392050 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -62,7 +62,9 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@AlphaComponent
-final class IDF extends Estimator[IDFModel] with IDFBase {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
+
+ def this() = this(Identifiable.randomUID("idf"))
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,7 +76,7 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val idf = new feature.IDF($(minDocFreq)).fit(input)
- copyValues(new IDFModel(this, idf))
+ copyValues(new IDFModel(uid, idf).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -88,7 +90,7 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
*/
@AlphaComponent
class IDFModel private[ml] (
- override val parent: IDF,
+ override val uid: String,
idfModel: feature.IDFModel)
extends Model[IDFModel] with IDFBase {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 755b46a64c..3f689d1585 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
@@ -29,7 +30,9 @@ import org.apache.spark.sql.types.DataType
* Normalize a vector to have unit norm using the given p-norm.
*/
@AlphaComponent
-class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
+class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
+
+ def this() = this(Identifiable.randomUID("normalizer"))
/**
* Normalization in L^p^ space. Must be >= 1.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 46514ae5f0..1fb9b9ae75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribu
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -37,8 +37,10 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
* linearly dependent because they sum up to one.
*/
@AlphaComponent
-class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
- with HasInputCol with HasOutputCol {
+class OneHotEncoder(override val uid: String)
+ extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("oneHot"))
/**
* Whether to include a component in the encoded vectors for the first category, defaults to true.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 9e6177ca27..41564410e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -34,7 +35,10 @@ import org.apache.spark.sql.types.DataType
* `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`.
*/
@AlphaComponent
-class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
+class PolynomialExpansion(override val uid: String)
+ extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
+
+ def this() = this(Identifiable.randomUID("poly"))
/**
* The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 7cad59ff3f..5ccda15d87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -55,7 +56,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
* statistics on the samples in the training set.
*/
@AlphaComponent
-class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
+ with StandardScalerParams {
+
+ def this() = this(Identifiable.randomUID("stdScal"))
setDefault(withMean -> false, withStd -> true)
@@ -76,7 +80,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
- copyValues(new StandardScalerModel(this, scalerModel))
+ copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -96,7 +100,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
*/
@AlphaComponent
class StandardScalerModel private[ml] (
- override val parent: StandardScaler,
+ override val uid: String,
scaler: feature.StandardScalerModel)
extends Model[StandardScalerModel] with StandardScalerParams {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3d78537ad8..3f79b67309 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
@@ -58,7 +59,10 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* So the most frequent label gets index 0.
*/
@AlphaComponent
-class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {
+class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
+ with StringIndexerBase {
+
+ def this() = this(Identifiable.randomUID("strIdx"))
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -73,7 +77,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
- copyValues(new StringIndexerModel(this, labels))
+ copyValues(new StringIndexerModel(uid, labels).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -87,7 +91,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
*/
@AlphaComponent
class StringIndexerModel private[ml] (
- override val parent: StringIndexer,
+ override val uid: String,
labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
private val labelToIndex: OpenHashMap[String, Double] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 649c217b16..36d9e17eca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
@@ -27,7 +28,9 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*/
@AlphaComponent
-class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
+class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
+
+ def this() = this(Identifiable.randomUID("tok"))
override protected def createTransformFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
@@ -48,7 +51,10 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
* It returns an array of strings that can be empty.
*/
@AlphaComponent
-class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+class RegexTokenizer(override val uid: String)
+ extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+
+ def this() = this(Identifiable.randomUID("regexTok"))
/**
* Minimum token length, >= 0.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 796758a70e..1c00094769 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -33,7 +34,10 @@ import org.apache.spark.sql.types._
* A feature transformer that merges multiple columns into a vector column.
*/
@AlphaComponent
-class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
+class VectorAssembler(override val uid: String)
+ extends Transformer with HasInputCols with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("va"))
/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 0f83a29c86..6d1d0524e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.callUDF
@@ -87,7 +87,10 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
* - Add option for allowing unknown categories.
*/
@AlphaComponent
-class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
+class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
+ with VectorIndexerParams {
+
+ def this() = this(Identifiable.randomUID("vecIdx"))
/** @group setParam */
def setMaxCategories(value: Int): this.type = set(maxCategories, value)
@@ -110,7 +113,9 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
iter.foreach(localCatStats.addVector)
Iterator(localCatStats)
}.reduce((stats1, stats2) => stats1.merge(stats2))
- copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps))
+ val model = new VectorIndexerModel(uid, numFeatures, categoryStats.getCategoryMaps)
+ .setParent(this)
+ copyValues(model)
}
override def transformSchema(schema: StructType): StructType = {
@@ -238,7 +243,7 @@ private object VectorIndexer {
*/
@AlphaComponent
class VectorIndexerModel private[ml] (
- override val parent: VectorIndexer,
+ override val uid: String,
val numFeatures: Int,
val categoryMaps: Map[Int, Map[Double, Int]])
extends Model[VectorIndexerModel] with VectorIndexerParams {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 34ff929701..8ace8c53bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
@@ -85,7 +85,9 @@ private[feature] trait Word2VecBase extends Params
* natural language processing or machine learning process.
*/
@AlphaComponent
-final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
+final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase {
+
+ def this() = this(Identifiable.randomUID("w2v"))
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -122,7 +124,7 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
.setSeed($(seed))
.setVectorSize($(vectorSize))
.fit(input)
- copyValues(new Word2VecModel(this, wordVectors))
+ copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -136,7 +138,7 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
*/
@AlphaComponent
class Word2VecModel private[ml] (
- override val parent: Word2Vec,
+ override val uid: String,
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 5a7ec29aac..247e08be1b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -40,12 +40,17 @@ import org.apache.spark.ml.util.Identifiable
* @tparam T param value type
*/
@AlphaComponent
-class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
+class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
extends Serializable {
- def this(parent: Params, name: String, doc: String) =
+ def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
+ this(parent.uid, name, doc, isValid)
+
+ def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue[T])
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
/**
* Assert that the given value is valid for this parameter.
*
@@ -60,8 +65,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
*/
private[param] def validate(value: T): Unit = {
if (!isValid(value)) {
- throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." +
- s" Parameter description: $toString")
+ throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.")
}
}
@@ -75,19 +79,15 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
*/
def ->(value: T): ParamPair[T] = ParamPair(this, value)
- /**
- * Converts this param's name, doc, and optionally its default value and the user-supplied
- * value in its parent to string.
- */
- override def toString: String = {
- val valueStr = if (parent.isDefined(this)) {
- val defaultValueStr = parent.getDefault(this).map("default: " + _)
- val currentValueStr = parent.get(this).map("current: " + _)
- (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
- } else {
- "(undefined)"
+ override final def toString: String = s"${parent}__$name"
+
+ override final def hashCode: Int = toString.##
+
+ override final def equals(obj: Any): Boolean = {
+ obj match {
+ case p: Param[_] => (p.parent == parent) && (p.name == name)
+ case _ => false
}
- s"$name: $doc $valueStr"
}
}
@@ -173,49 +173,71 @@ object ParamValidators {
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
+class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)
extends Param[Double](parent, name, doc, isValid) {
- def this(parent: Params, name: String, doc: String) =
+ def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
+ def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) =
+ this(parent.uid, name, doc, isValid)
+
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
+class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
extends Param[Int](parent, name, doc, isValid) {
- def this(parent: Params, name: String, doc: String) =
+ def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
+ def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) =
+ this(parent.uid, name, doc, isValid)
+
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
+class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)
extends Param[Float](parent, name, doc, isValid) {
- def this(parent: Params, name: String, doc: String) =
+ def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
+ def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) =
+ this(parent.uid, name, doc, isValid)
+
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
+class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)
extends Param[Long](parent, name, doc, isValid) {
- def this(parent: Params, name: String, doc: String) =
+ def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
+ def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) =
+ this(parent.uid, name, doc, isValid)
+
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid
+class BooleanParam(parent: String, name: String, doc: String) // No need for isValid
extends Param[Boolean](parent, name, doc) {
+ def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -265,6 +287,9 @@ trait Params extends Identifiable with Serializable {
/**
* Returns all params sorted by their names. The default implementation uses Java reflection to
* list all public methods that have no arguments and return [[Param]].
+ *
+ * Note: Developers should not use this method in constructors because we cannot guarantee that
+ * this variable gets initialized before other params.
*/
lazy val params: Array[Param[_]] = {
val methods = this.getClass.getMethods
@@ -299,15 +324,36 @@ trait Params extends Identifiable with Serializable {
* those are checked during schema validation.
*/
def validateParams(): Unit = {
- params.filter(isDefined _).foreach { param =>
+ params.filter(isDefined).foreach { param =>
param.asInstanceOf[Param[Any]].validate($(param))
}
}
/**
- * Returns the documentation of all params.
+ * Explains a param.
+ * @param param input param, must belong to this instance.
+ * @return a string that contains the input param name, doc, and optionally its default value and
+ * the user-supplied value
+ */
+ def explainParam(param: Param[_]): String = {
+ shouldOwn(param)
+ val valueStr = if (isDefined(param)) {
+ val defaultValueStr = getDefault(param).map("default: " + _)
+ val currentValueStr = get(param).map("current: " + _)
+ (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
+ } else {
+ "(undefined)"
+ }
+ s"${param.name}: ${param.doc} $valueStr"
+ }
+
+ /**
+ * Explains all params of this instance.
+ * @see [[explainParam()]]
*/
- def explainParams(): String = params.mkString("\n")
+ def explainParams(): String = {
+ params.map(explainParam).mkString("\n")
+ }
/** Checks whether a param is explicitly set. */
final def isSet(param: Param[_]): Boolean = {
@@ -392,7 +438,6 @@ trait Params extends Identifiable with Serializable {
* @param value the default value
*/
protected final def setDefault[T](param: Param[T], value: T): this.type = {
- shouldOwn(param)
defaultParamMap.put(param, value)
this
}
@@ -430,13 +475,13 @@ trait Params extends Identifiable with Serializable {
}
/**
- * Creates a copy of this instance with a randomly generated uid and some extra params.
- * The default implementation calls the default constructor to create a new instance, then
- * copies the embedded and extra parameters over and returns the new instance.
+ * Creates a copy of this instance with the same UID and some extra params.
+ * The default implementation tries to create a new instance with the same UID.
+ * Then it copies the embedded and extra parameters over and returns the new instance.
* Subclasses should override this method if the default approach is not sufficient.
*/
def copy(extra: ParamMap): Params = {
- val that = this.getClass.newInstance()
+ val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
copyValues(that, extra)
that
}
@@ -465,7 +510,7 @@ trait Params extends Identifiable with Serializable {
/** Validates that the input param belongs to this instance. */
private def shouldOwn(param: Param[_]): Unit = {
- require(param.parent.eq(this), s"Param $param does not belong to $this.")
+ require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.")
}
/**
@@ -581,7 +626,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
override def toString: String = {
map.toSeq.sortBy(_._1.name).map { case (param, value) =>
- s"\t${param.parent.uid}-${param.name}: $value"
+ s"\t${param.parent}-${param.name}: $value"
}.mkString("{\n", ",\n", "\n}")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index d7cbffc3be..45c57b50da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
@@ -171,7 +172,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* Model fitted by ALS.
*/
class ALSModel private[ml] (
- override val parent: ALS,
+ override val uid: String,
k: Int,
userFactors: RDD[(Int, Array[Float])],
itemFactors: RDD[(Int, Array[Float])])
@@ -235,10 +236,12 @@ class ALSModel private[ml] (
* indicated user
* preferences rather than explicit ratings given to items.
*/
-class ALS extends Estimator[ALSModel] with ALSParams {
+class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
import org.apache.spark.ml.recommendation.ALS.Rating
+ def this() = this(Identifiable.randomUID("als"))
+
/** @group setParam */
def setRank(value: Int): this.type = set(rank, value)
@@ -303,7 +306,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
checkpointInterval = $(checkpointInterval), seed = $(seed))
- copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
+ val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this)
+ copyValues(model)
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index f8f0b161a4..e67df21b2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
@@ -38,10 +38,12 @@ import org.apache.spark.sql.DataFrame
* It supports both continuous and categorical features.
*/
@AlphaComponent
-final class DecisionTreeRegressor
+final class DecisionTreeRegressor(override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeParams with TreeRegressorParams {
+ def this() = this(Identifiable.randomUID("dtr"))
+
// Override parameter setters from parent trait for Java API compatibility.
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
@@ -91,7 +93,7 @@ object DecisionTreeRegressor {
*/
@AlphaComponent
final class DecisionTreeRegressionModel private[ml] (
- override val parent: DecisionTreeRegressor,
+ override val uid: String,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with Serializable {
@@ -104,7 +106,7 @@ final class DecisionTreeRegressionModel private[ml] (
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
- copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra)
+ copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra)
}
override def toString: String = {
@@ -128,6 +130,7 @@ private[ml] object DecisionTreeRegressionModel {
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
- new DecisionTreeRegressionModel(parent, rootNode)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
+ new DecisionTreeRegressionModel(uid, rootNode)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 461905c127..4249ff5c1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
@@ -42,10 +42,12 @@ import org.apache.spark.sql.DataFrame
* It supports both continuous and categorical features.
*/
@AlphaComponent
-final class GBTRegressor
+final class GBTRegressor(override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
with GBTParams with TreeRegressorParams with Logging {
+ def this() = this(Identifiable.randomUID("gbtr"))
+
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeRegressorParams:
@@ -149,7 +151,7 @@ object GBTRegressor {
*/
@AlphaComponent
final class GBTRegressionModel(
- override val parent: GBTRegressor,
+ override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTRegressionModel]
@@ -173,7 +175,7 @@ final class GBTRegressionModel(
}
override def copy(extra: ParamMap): GBTRegressionModel = {
- copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra)
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -199,6 +201,7 @@ private[ml] object GBTRegressionModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTRegressionModel(parent, newTrees, oldModel.treeWeights)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
+      new GBTRegressionModel(uid, newTrees, oldModel.treeWeights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 36c242bb5f..3ebb78f792 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -20,14 +20,14 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
- OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
@@ -59,9 +59,12 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* - L2 + L1 (elastic net)
*/
@AlphaComponent
-class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
+class LinearRegression(override val uid: String)
+ extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams with Logging {
+ def this() = this(Identifiable.randomUID("linReg"))
+
/**
* Set the regularization parameter.
* Default is 0.0.
@@ -128,7 +131,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
s"and the intercept will be the mean of the label; as a result, training is not needed.")
if (handlePersistence) instances.unpersist()
- return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean)
+ return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
}
val featuresMean = summarizer.mean.toArray
@@ -182,7 +185,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
if (handlePersistence) instances.unpersist()
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
- new LinearRegressionModel(this, weights.compressed, intercept)
+ copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
}
}
@@ -193,7 +196,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
*/
@AlphaComponent
class LinearRegressionModel private[ml] (
- override val parent: LinearRegression,
+ override val uid: String,
val weights: Vector,
val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
@@ -204,7 +207,7 @@ class LinearRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LinearRegressionModel = {
- copyValues(new LinearRegressionModel(parent, weights, intercept), extra)
+ copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
}
}
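Because models are now built from a uid instead of a typed parent, callers that need estimator-specific params cast model.parent back to the concrete class, as the test updates further down do. A minimal, illustrative sketch (it assumes a DataFrame of labeled points named dataset is already in scope):

    import org.apache.spark.ml.regression.LinearRegression

    val lr = new LinearRegression().setMaxIter(10).setRegParam(1.0)
    val model = lr.fit(dataset)
    // parent is now typed as a generic Estimator, so recover the concrete class with a cast
    val parent = model.parent.asInstanceOf[LinearRegression]
    assert(parent.uid == model.uid)   // the fitted model reuses the estimator's uid
    assert(parent.getMaxIter == 10)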
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index dbc6289274..82437aa8de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
@@ -37,10 +37,12 @@ import org.apache.spark.sql.DataFrame
* It supports both continuous and categorical features.
*/
@AlphaComponent
-final class RandomForestRegressor
+final class RandomForestRegressor(override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestParams with TreeRegressorParams {
+ def this() = this(Identifiable.randomUID("rfr"))
+
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeRegressorParams:
@@ -105,7 +107,7 @@ object RandomForestRegressor {
*/
@AlphaComponent
final class RandomForestRegressionModel private[ml] (
- override val parent: RandomForestRegressor,
+ override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel])
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -128,7 +130,7 @@ final class RandomForestRegressionModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
- copyValues(new RandomForestRegressionModel(parent, _trees), extra)
+ copyValues(new RandomForestRegressionModel(uid, _trees), extra)
}
override def toString: String = {
@@ -154,6 +156,6 @@ private[ml] object RandomForestRegressionModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent, newTrees)
+ new RandomForestRegressionModel(parent.uid, newTrees)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index ac0d1fed84..5c6ff2dda3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -23,6 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -81,7 +82,10 @@ private[ml] trait CrossValidatorParams extends Params {
* K-fold cross validation.
*/
@AlphaComponent
-class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {
+class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
+ with CrossValidatorParams with Logging {
+
+ def this() = this(Identifiable.randomUID("cv"))
private val f2jBLAS = new F2jBLAS
@@ -136,7 +140,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(this, bestModel))
+ copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -150,7 +154,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
*/
@AlphaComponent
class CrossValidatorModel private[ml] (
- override val parent: CrossValidator,
+ override val uid: String,
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
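Since models no longer receive their parent through the constructor, fit() now wires the link up explicitly (copyValues plus setParent, as in the hunk above). A sketch of the user-visible effect, assuming a suitable binary-labeled dataset DataFrame is in scope:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.param.ParamGridBuilder
    import org.apache.spark.ml.tuning.CrossValidator

    val lr = new LogisticRegression()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(
        new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).build())
      .setNumFolds(2)
    val cvModel = cv.fit(dataset)
    // the model shares the validator's uid and records it as parent
    assert(cvModel.uid == cv.uid && cvModel.parent.eq(cv))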
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
index 8a56748ab0..1466976800 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
@@ -19,15 +19,24 @@ package org.apache.spark.ml.util
import java.util.UUID
+
/**
- * Object with a unique id.
+ * Trait for an object with an immutable unique ID that identifies itself and its derivatives.
*/
-private[ml] trait Identifiable extends Serializable {
+trait Identifiable {
+
+ /**
+ * An immutable unique ID for the object and its derivatives.
+ */
+ val uid: String
+}
+
+object Identifiable {
/**
- * A unique id for the object. The default implementation concatenates the class name, "_", and 8
- * random hex chars.
+ * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.
*/
- private[ml] val uid: String =
- this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8)
+ def randomUID(prefix: String): String = {
+ prefix + "_" + UUID.randomUUID().toString.takeRight(12)
+ }
}
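For orientation, the new helper can be exercised on its own. The snippet below is illustrative only (the "als" prefix is borrowed from the ALS default constructor above); randomUID appends an underscore and the last 12 hex characters of a random UUID:

    import org.apache.spark.ml.util.Identifiable

    // e.g. "als_9e0a4f2b8c1d" -- prefix, underscore, 12 random hex chars
    val uid = Identifiable.randomUID("als")
    assert(uid.startsWith("als_") && uid.length == "als".length + 1 + 12)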
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e7189a2b1..f75e024a71 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setThreshold(0.6)
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- LogisticRegression parent = model.parent();
+ LogisticRegression parent = (LogisticRegression) model.parent();
assert(parent.getMaxIter() == 10);
assert(parent.getRegParam() == 1.0);
assert(parent.getThreshold() == 0.6);
@@ -110,7 +110,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
- LogisticRegression parent2 = model2.parent();
+ LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
assert(parent2.getThreshold() == 0.4);
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 8abe575610..3a41890b92 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -21,43 +21,65 @@ import java.util.List;
import com.google.common.collect.Lists;
+import org.apache.spark.ml.util.Identifiable$;
+
/**
* A subclass of Params for testing.
*/
public class JavaTestParams extends JavaParams {
- public IntParam myIntParam;
+ public JavaTestParams() {
+ this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams");
+ init();
+ }
+
+ public JavaTestParams(String uid) {
+ this.uid_ = uid;
+ init();
+ }
+
+ private String uid_;
+
+ @Override
+ public String uid() {
+ return uid_;
+ }
- public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+ private IntParam myIntParam_;
+ public IntParam myIntParam() { return myIntParam_; }
+
+ public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
public JavaTestParams setMyIntParam(int value) {
- set(myIntParam, value); return this;
+ set(myIntParam_, value); return this;
}
- public DoubleParam myDoubleParam;
+ private DoubleParam myDoubleParam_;
+ public DoubleParam myDoubleParam() { return myDoubleParam_; }
- public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+ public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
public JavaTestParams setMyDoubleParam(double value) {
- set(myDoubleParam, value); return this;
+ set(myDoubleParam_, value); return this;
}
- public Param<String> myStringParam;
+ private Param<String> myStringParam_;
+ public Param<String> myStringParam() { return myStringParam_; }
- public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+ public String getMyStringParam() { return getOrDefault(myStringParam_); }
public JavaTestParams setMyStringParam(String value) {
- set(myStringParam, value); return this;
+ set(myStringParam_, value); return this;
}
- public JavaTestParams() {
- myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
- myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+ private void init() {
+ myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+ myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Lists.newArrayList("a", "b");
- myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+ myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
- setDefault(myIntParam, 1);
- setDefault(myDoubleParam, 0.5);
+ setDefault(myIntParam_, 1);
+ setDefault(myDoubleParam_, 0.5);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index a82b86d560..d591a45686 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -77,14 +77,14 @@ public class JavaLinearRegressionSuite implements Serializable {
.setMaxIter(10)
.setRegParam(1.0);
LinearRegressionModel model = lr.fit(dataset);
- LinearRegression parent = model.parent();
+ LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
assertEquals(1.0, parent.getRegParam(), 0.0);
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
- LinearRegression parent2 = model2.parent();
+ LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0);
assertEquals("thePred", model2.getPredictionCol());
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
new file mode 100644
index 0000000000..67c262d0f9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.scalatest.FunSuite
+
+class IdentifiableSuite extends FunSuite {
+
+ import IdentifiableSuite.Test
+
+ test("Identifiable") {
+ val test0 = new Test("test_0")
+ assert(test0.uid === "test_0")
+
+ val test1 = new Test
+ assert(test1.uid.startsWith("test_"))
+ }
+}
+
+object IdentifiableSuite {
+
+ class Test(override val uid: String) extends Identifiable {
+ def this() = this(Identifiable.randomUID("test"))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 03af4ecd7a..3fdc66be8a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -268,7 +268,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
val newTree = dt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
- oldTree, newTree.parent, categoricalFeatures)
+ oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 16c758b82c..ea86867f11 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -130,7 +130,7 @@ private object GBTClassifierSuite {
val newModel = gbt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
- oldModel, newModel.parent, categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 4df8016009..43765241a2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.ml.classification
import org.scalatest.FunSuite
-import org.apache.spark.mllib.classification.LogisticRegressionSuite
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@@ -37,8 +36,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
super.beforeAll()
sqlContext = new SQLContext(sc)
- dataset = sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite
- .generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 4))
+ dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
/**
* Here is the instruction describing how to export the test data into CSV format
@@ -60,31 +58,30 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
val xMean = Array(5.843, 3.057, 3.758, 1.199)
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
- val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
- weights, xMean, xVariance, true, nPoints, 42)
+ val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
- sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite
- .generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 4))
+ sqlContext.createDataFrame(
+ generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
}
}
test("logistic regression: default params") {
val lr = new LogisticRegression
- assert(lr.getLabelCol == "label")
- assert(lr.getFeaturesCol == "features")
- assert(lr.getPredictionCol == "prediction")
- assert(lr.getRawPredictionCol == "rawPrediction")
- assert(lr.getProbabilityCol == "probability")
- assert(lr.getFitIntercept == true)
+ assert(lr.getLabelCol === "label")
+ assert(lr.getFeaturesCol === "features")
+ assert(lr.getPredictionCol === "prediction")
+ assert(lr.getRawPredictionCol === "rawPrediction")
+ assert(lr.getProbabilityCol === "probability")
+ assert(lr.getFitIntercept)
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
.collect()
assert(model.getThreshold === 0.5)
- assert(model.getFeaturesCol == "features")
- assert(model.getPredictionCol == "prediction")
- assert(model.getRawPredictionCol == "rawPrediction")
- assert(model.getProbabilityCol == "probability")
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.getRawPredictionCol === "rawPrediction")
+ assert(model.getProbabilityCol === "probability")
assert(model.intercept !== 0.0)
}
@@ -103,7 +100,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
.setThreshold(0.6)
.setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- val parent = model.parent
+ val parent = model.parent.asInstanceOf[LogisticRegression]
assert(parent.getMaxIter === 10)
assert(parent.getRegParam === 1.0)
assert(parent.getThreshold === 0.6)
@@ -129,12 +126,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
// Call fit() with new params, and check as many params as we can.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
lr.probabilityCol -> "theProb")
- val parent2 = model2.parent
+ val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)
assert(parent2.getRegParam === 0.1)
assert(parent2.getThreshold === 0.4)
assert(model2.getThreshold === 0.4)
- assert(model2.getProbabilityCol == "theProb")
+ assert(model2.getProbabilityCol === "theProb")
}
test("logistic regression: Predictor, Classifier methods") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index e65ffae918..990cfb08af 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -57,7 +57,7 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
test("one-vs-rest: default params") {
val numClasses = 3
val ova = new OneVsRest()
- ova.setClassifier(new LogisticRegression)
+ .setClassifier(new LogisticRegression)
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
val ovaModel = ova.fit(dataset)
@@ -97,7 +97,9 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
}
}
-private class MockLogisticRegression extends LogisticRegression {
+private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
+
+ def this() = this("mockLogReg")
setMaxIter(1)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index c41def9330..08f86fa45b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -160,7 +160,7 @@ private object RandomForestClassifierSuite {
val newModel = rf.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
- oldModel, newModel.parent, categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 6056e7d3f6..b96874f3a8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -23,21 +23,22 @@ class ParamsSuite extends FunSuite {
test("param") {
val solver = new TestParams()
+ val uid = solver.uid
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
assert(maxIter.doc === "max number of iterations (>= 0)")
- assert(maxIter.parent.eq(solver))
- assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)")
+ assert(maxIter.parent === uid)
+ assert(maxIter.toString === s"${uid}__maxIter")
assert(!maxIter.isValid(-1))
assert(maxIter.isValid(0))
assert(maxIter.isValid(1))
solver.setMaxIter(5)
- assert(maxIter.toString ===
+ assert(solver.explainParam(maxIter) ===
"maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
- assert(inputCol.toString === "inputCol: input column name (undefined)")
+ assert(inputCol.toString === s"${uid}__inputCol")
intercept[IllegalArgumentException] {
solver.setMaxIter(-1)
@@ -118,7 +119,10 @@ class ParamsSuite extends FunSuite {
assert(!solver.isDefined(inputCol))
intercept[NoSuchElementException](solver.getInputCol)
- assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+ assert(solver.explainParam(maxIter) ===
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
+ assert(solver.explainParams() ===
+ Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter))
@@ -148,7 +152,7 @@ class ParamsSuite extends FunSuite {
assert(!solver.isSet(maxIter))
val copied = solver.copy(ParamMap(solver.maxIter -> 50))
- assert(copied.uid !== solver.uid)
+ assert(copied.uid === solver.uid)
assert(copied.getInputCol === solver.getInputCol)
assert(copied.getMaxIter === 50)
}
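As the assertions above indicate, Param.toString is now just the structural identifier "<parentUid>__<name>", the human-readable description (with default and current values) moves to explainParam/explainParams, and copy() preserves the uid. A small sketch using the TestParams helper updated in the next file (the uid string "my_params" is made up for readability; everything lives in the org.apache.spark.ml.param test sources, so ParamMap is in scope):

    val solver = new TestParams("my_params").setMaxIter(5)
    assert(solver.maxIter.toString == "my_params__maxIter")
    assert(solver.explainParam(solver.maxIter) ==
      "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
    val copied = solver.copy(ParamMap(solver.maxIter -> 50))
    assert(copied.uid == solver.uid)   // copies keep the original uid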
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index dc16073640..a9e78366ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -18,9 +18,12 @@
package org.apache.spark.ml.param
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
+import org.apache.spark.ml.util.Identifiable
/** A subclass of Params for testing. */
-class TestParams extends Params with HasMaxIter with HasInputCol {
+class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
+
+ def this() = this(Identifiable.randomUID("testParams"))
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 5aa81b44dd..1196a772df 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -85,7 +85,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
val newTree = dt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
- oldTree, newTree.parent, categoricalFeatures)
+ oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 25b36ab08b..40e7e3273e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -130,7 +130,8 @@ private object GBTRegressorSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
- val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures)
+ val oldModelAsNew = GBTRegressionModel.fromOld(
+ oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 45f09f4fda..3efffbb763 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -116,7 +116,7 @@ private object RandomForestRegressorSuite extends FunSuite {
val newModel = rf.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
- oldModel, newModel.parent, categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}