author	Xiangrui Meng <meng@databricks.com>	2015-05-04 11:28:59 -0700
committer	Xiangrui Meng <meng@databricks.com>	2015-05-04 11:28:59 -0700
commit	e0833c5958bbd73ff27cfe6865648d7b6e5a99bc (patch)
tree	373883fa46f206ffcd34c4d0b67ce246b61bbc93
parent	5a1a1075a607be683f008ef92fa227803370c45f (diff)
download	spark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.tar.gz
	spark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.tar.bz2
	spark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.zip
[SPARK-5956] [MLLIB] Pipeline components should be copyable.
This PR added `copy(extra: ParamMap): Params` to `Params`, which makes a copy of the current instance with a randomly generated uid and some extra param values. With this change, we only need to implement `fit` and `transform` without extra param values, given the default implementation of `fit(dataset, extra)`:

~~~scala
def fit(dataset: DataFrame, extra: ParamMap): Model = {
  copy(extra).fit(dataset)
}
~~~

Inside `fit` and `transform`, since only the embedded values are used, I added `$` as an alias for `getOrDefault` to make the code easier to read. For example, in `LinearRegression.fit` we have:

~~~scala
val effectiveRegParam = $(regParam) / yStd
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
~~~

Meta-algorithms like `Pipeline` implement their own `copy(extra)`, so the fitted pipeline model stores all copied stages (no matter whether a stage is a transformer or a model).

Other changes:

* `Params$.inheritValues` is moved to `Params.copyValues` and returns the target instance.
* `fittingParamMap` was removed because the `parent` carries this information.
* `validate` was renamed to `validateParams` to be more precise.

TODOs:

* [x] add tests for newly added methods
* [ ] update documentation

jkbradley dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #5820 from mengxr/SPARK-5956 and squashes the following commits:

7bef88d [Xiangrui Meng] address comments
05229c3 [Xiangrui Meng] assert -> assertEquals
b2927b1 [Xiangrui Meng] organize imports
f14456b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
93e7924 [Xiangrui Meng] add tests for hasParam & copy
463ecae [Xiangrui Meng] merge master
2b954c3 [Xiangrui Meng] update Binarizer
465dd12 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
282a1a8 [Xiangrui Meng] fix test
819dd2d [Xiangrui Meng] merge master
b642872 [Xiangrui Meng] example code runs
5a67779 [Xiangrui Meng] examples compile
c76b4d1 [Xiangrui Meng] fix all unit tests
0f4fd64 [Xiangrui Meng] fix some tests
9286a22 [Xiangrui Meng] copyValues to trained models
53e0973 [Xiangrui Meng] move inheritValues to Params and rename it to copyValues
9ee004e [Xiangrui Meng] merge copy and copyWith; rename validate to validateParams
d882afc [Xiangrui Meng] test compile
f082a31 [Xiangrui Meng] make Params copyable and simply handling of extra params in all spark.ml components
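As a quick illustration of the user-facing effect (a minimal sketch; `training` is an assumed DataFrame of labeled points, and the parameter values are arbitrary), extra params now flow through `copy(extra)` and are recovered from the model's `parent` instead of a `fittingParamMap`:

~~~scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression().setMaxIter(10)

// Uses only the embedded params; this calls the new abstract fit(dataset).
val model1 = lr.fit(training)

// With extra params, fit(dataset, paramMap) is now copy(paramMap).fit(dataset).
val paramMap = ParamMap(lr.maxIter -> 30, lr.regParam -> 0.1)
val model2 = lr.fit(training, paramMap)

// fittingParamMap is removed; the parent estimator carries the fitted params.
println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
~~~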
-rw-r--r--	examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java	24
-rw-r--r--	examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java	4
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala	6
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala	22
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala	4
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala	6
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala	4
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Estimator.scala	26
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala	20
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Model.scala	9
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala	106
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Transformer.scala	46
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala	49
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala	29
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala	33
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala	58
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala	33
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala	31
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala	17
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala	20
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala	10
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala	38
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala	10
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala	9
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala	49
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala	34
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala	18
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala	15
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala	74
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala	62
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala	72
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala	35
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/param/params.scala	75
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala	5
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala	35
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala	73
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala	23
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala	30
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala	41
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala	27
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala	2
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala	51
-rw-r--r--	mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java	14
-rw-r--r--	mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java	21
-rw-r--r--	mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java	6
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala	26
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala	4
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala	4
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala	18
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala	4
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala	13
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala	14
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala	4
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala	3
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala	4
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala	6
56 files changed, 671 insertions, 805 deletions
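The example diffs below also replace `PipelineModel.getModel` with direct access to the fitted stages. A minimal sketch of the new access pattern (assuming a single-stage pipeline and an assumed `training` DataFrame with suitable label/feature columns):

~~~scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}

val dt = new DecisionTreeClassifier()
val pipeline = new Pipeline().setStages(Array(dt))
val pipelineModel = pipeline.fit(training) // `training` is an assumed DataFrame

// getModel[...] is removed; the fitted PipelineModel now exposes `stages` directly,
// so a fitted stage is retrieved by position and cast to its concrete type.
val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
println(treeModel) // prints a model summary
~~~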
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 46377a99c4..eac4f898a4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -129,16 +128,16 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
- JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
// Do learning to estimate the weight vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ return new MyJavaLogisticRegressionModel(this, weights);
}
}
@@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel
private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
- private ParamMap fittingParamMap_;
- public ParamMap fittingParamMap() { return fittingParamMap_; }
-
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(
- MyJavaLogisticRegression parent_,
- ParamMap fittingParamMap_,
- Vector weights_) {
+ public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
- this.fittingParamMap_ = fittingParamMap_;
this.weights_ = weights_;
}
@@ -210,10 +202,8 @@ class MyJavaLogisticRegressionModel
* In Java, we have to make this method public since Java does not understand Scala's protected
* modifier.
*/
- public MyJavaLogisticRegressionModel copy() {
- MyJavaLogisticRegressionModel m =
- new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
- Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
- return m;
+ @Override
+ public MyJavaLogisticRegressionModel copy(ParamMap extra) {
+ return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 4e02acce69..29158d5c85 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -71,7 +71,7 @@ public class JavaSimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+ System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -87,7 +87,7 @@ public class JavaSimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
- System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+ System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List<LabeledPoint> localTest = Lists.newArrayList(
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 9002e99d82..8340d91101 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
- val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
- dt.asInstanceOf[DecisionTreeClassifier])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
- val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
- dt.asInstanceOf[DecisionTreeRegressor])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2245fa429f..2a2d067727 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -18,13 +18,12 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
-import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
+import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
@@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = getOrDefault(maxIter)
+ def getMaxIter: Int = $(maxIter)
}
/**
@@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): MyLogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, paramMap, weights)
+ new MyLogisticRegressionModel(this, weights)
}
}
@@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
- override protected def copy(): MyLogisticRegressionModel = {
- val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): MyLogisticRegressionModel = {
+ copyValues(new MyLogisticRegressionModel(parent, weights), extra)
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 5fccb142d4..c5899b6683 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 9b909324ec..7f88d2681b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
- dt.asInstanceOf[RandomForestClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
- dt.asInstanceOf[RandomForestRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index bf805149d0..e8a991f50e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
- println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+ println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
// Prepare test data.
val test = sc.parallelize(Seq(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index d6b3503ebd..7f3f3262a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -34,13 +34,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
- * @param paramPairs Optional list of param pairs.
- * These values override any specified in this Estimator's embedded ParamMap.
+ * @param firstParamPair the first param pair, overrides embedded params
+ * @param otherParamPairs other param pairs. These values override any specified in this
+ * Estimator's embedded ParamMap.
* @return fitted model
*/
@varargs
- def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
- val map = ParamMap(paramPairs: _*)
+ def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+ val map = new ParamMap()
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
fit(dataset, map)
}
@@ -52,12 +55,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
- def fit(dataset: DataFrame, paramMap: ParamMap): M
+ def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ copy(paramMap).fit(dataset)
+ }
+
+ /**
+ * Fits a model to the input data.
+ */
+ def fit(dataset: DataFrame): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
* The default implementation uses a for loop on each parameter map.
- * Subclasses could overwrite this to optimize multi-model training.
+ * Subclasses could override this to optimize multi-model training.
*
* @param dataset input dataset
* @param paramMaps An array of parameter maps.
@@ -67,4 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
+
+ override def copy(extra: ParamMap): Estimator[M] = {
+ super.copy(extra).asInstanceOf[Estimator[M]]
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
index 8b4b5fd8af..5f2f8c94e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -18,8 +18,7 @@
package org.apache.spark.ml
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.DataFrame
/**
@@ -27,7 +26,7 @@ import org.apache.spark.sql.DataFrame
* Abstract class for evaluators that compute metrics from predictions.
*/
@AlphaComponent
-abstract class Evaluator extends Identifiable {
+abstract class Evaluator extends Params {
/**
* Evaluates the output.
@@ -36,5 +35,18 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
+ def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
+ this.copy(paramMap).evaluate(dataset)
+ }
+
+ /**
+ * Evaluates the output.
+ * @param dataset a dataset that contains labels/observations and predictions.
+ * @return metric
+ */
+ def evaluate(dataset: DataFrame): Double
+
+ override def copy(extra: ParamMap): Evaluator = {
+ super.copy(extra).asInstanceOf[Evaluator]
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index a491bc7ee8..9974efe7b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -34,9 +34,8 @@ abstract class Model[M <: Model[M]] extends Transformer {
*/
val parent: Estimator[M]
- /**
- * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
- * Note: For ensembles' component Models, this value can be null.
- */
- val fittingParamMap: ParamMap
+ override def copy(extra: ParamMap): M = {
+ // The default implementation of Params.copy doesn't work for models.
+ throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 6bfeecd764..33d430f567 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{Params, Param, ParamMap}
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -30,40 +30,41 @@ import org.apache.spark.sql.types.StructType
* A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
*/
@AlphaComponent
-abstract class PipelineStage extends Serializable with Logging {
+abstract class PipelineStage extends Params with Logging {
/**
* :: DeveloperApi ::
*
- * Derives the output schema from the input schema and parameters.
- * The schema describes the columns and types of the data.
- *
- * @param schema Input schema to this stage
- * @param paramMap Parameters passed to this stage
- * @return Output schema from this stage
+ * Derives the output schema from the input schema.
*/
@DeveloperApi
- def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+ def transformSchema(schema: StructType): StructType
/**
+ * :: DeveloperApi ::
+ *
* Derives the output schema from the input schema and parameters, optionally with logging.
*
* This should be optimistic. If it is unclear whether the schema will be valid, then it should
* be assumed valid until proven otherwise.
*/
+ @DeveloperApi
protected def transformSchema(
schema: StructType,
- paramMap: ParamMap,
logging: Boolean): StructType = {
if (logging) {
logDebug(s"Input schema: ${schema.json}")
}
- val outputSchema = transformSchema(schema, paramMap)
+ val outputSchema = transformSchema(schema)
if (logging) {
logDebug(s"Expected output schema: ${outputSchema.json}")
}
outputSchema
}
+
+ override def copy(extra: ParamMap): PipelineStage = {
+ super.copy(extra).asInstanceOf[PipelineStage]
+ }
}
/**
@@ -81,15 +82,22 @@ abstract class PipelineStage extends Serializable with Logging {
@AlphaComponent
class Pipeline extends Estimator[PipelineModel] {
- /** param for pipeline stages */
+ /**
+ * param for pipeline stages
+ * @group param
+ */
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+
+ /** @group setParam */
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
- def getStages: Array[PipelineStage] = getOrDefault(stages)
- override def validate(paramMap: ParamMap): Unit = {
+ /** @group getParam */
+ def getStages: Array[PipelineStage] = $(stages).clone()
+
+ override def validateParams(paramMap: ParamMap): Unit = {
val map = extractParamMap(paramMap)
getStages.foreach {
- case pStage: Params => pStage.validate(map)
+ case pStage: Params => pStage.validateParams(map)
case _ =>
}
}
@@ -104,13 +112,11 @@ class Pipeline extends Estimator[PipelineModel] {
* pipeline stages. If there are no stages, the output model acts as an identity transformer.
*
* @param dataset input dataset
- * @param paramMap parameter map
* @return fitted pipeline
*/
- override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val theStages = map(stages)
+ override def fit(dataset: DataFrame): PipelineModel = {
+ transformSchema(dataset.schema, logging = true)
+ val theStages = $(stages)
// Search for the last estimator.
var indexOfLastEstimator = -1
theStages.view.zipWithIndex.foreach { case (stage, index) =>
@@ -126,7 +132,7 @@ class Pipeline extends Estimator[PipelineModel] {
if (index <= indexOfLastEstimator) {
val transformer = stage match {
case estimator: Estimator[_] =>
- estimator.fit(curDataset, paramMap)
+ estimator.fit(curDataset)
case t: Transformer =>
t
case _ =>
@@ -134,7 +140,7 @@ class Pipeline extends Estimator[PipelineModel] {
s"Do not support stage $stage of type ${stage.getClass}")
}
if (index < indexOfLastEstimator) {
- curDataset = transformer.transform(curDataset, paramMap)
+ curDataset = transformer.transform(curDataset)
}
transformers += transformer
} else {
@@ -142,15 +148,20 @@ class Pipeline extends Estimator[PipelineModel] {
}
}
- new PipelineModel(this, map, transformers.toArray)
+ new PipelineModel(this, transformers.toArray)
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val theStages = map(stages)
+ override def copy(extra: ParamMap): Pipeline = {
+ val map = extractParamMap(extra)
+ val newStages = map(stages).map(_.copy(extra))
+ new Pipeline().setStages(newStages)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
- theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
+ theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
}
@@ -161,43 +172,24 @@ class Pipeline extends Estimator[PipelineModel] {
@AlphaComponent
class PipelineModel private[ml] (
override val parent: Pipeline,
- override val fittingParamMap: ParamMap,
- private[ml] val stages: Array[Transformer])
+ val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
- override def validate(paramMap: ParamMap): Unit = {
- val map = fittingParamMap ++ extractParamMap(paramMap)
- stages.foreach(_.validate(map))
+ override def validateParams(): Unit = {
+ super.validateParams()
+ stages.foreach(_.validateParams())
}
- /**
- * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
- * estimator does not exist in the pipeline.
- */
- def getModel[M <: Model[M]](stage: Estimator[M]): M = {
- val matched = stages.filter {
- case m: Model[_] => m.parent.eq(stage)
- case _ => false
- }
- if (matched.isEmpty) {
- throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
- } else if (matched.length > 1) {
- throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
- } else {
- matched.head.asInstanceOf[M]
- }
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = fittingParamMap ++ extractParamMap(paramMap)
- transformSchema(dataset.schema, map, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
+ override def transformSchema(schema: StructType): StructType = {
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = fittingParamMap ++ extractParamMap(paramMap)
- stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
+ override def copy(extra: ParamMap): PipelineModel = {
+ new PipelineModel(parent, stages)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 0acda71ec6..d96b54e511 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -37,13 +37,18 @@ abstract class Transformer extends PipelineStage with Params {
/**
* Transforms the dataset with optional parameters
* @param dataset input dataset
- * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @param firstParamPair the first param pair, overwrite embedded params
+ * @param otherParamPairs other param pairs, overwrite embedded params
* @return transformed dataset
*/
@varargs
- def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
+ def transform(
+ dataset: DataFrame,
+ firstParamPair: ParamPair[_],
+ otherParamPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
- paramPairs.foreach(map.put(_))
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
transform(dataset, map)
}
@@ -53,7 +58,18 @@ abstract class Transformer extends PipelineStage with Params {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
+ def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ this.copy(paramMap).transform(dataset)
+ }
+
+ /**
+ * Transforms the input dataset.
+ */
+ def transform(dataset: DataFrame): DataFrame
+
+ override def copy(extra: ParamMap): Transformer = {
+ super.copy(extra).asInstanceOf[Transformer]
+ }
}
/**
@@ -74,7 +90,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
* account of the embedded param map. So the param values should be determined solely by the input
* param map.
*/
- protected def createTransformFunc(paramMap: ParamMap): IN => OUT
+ protected def createTransformFunc: IN => OUT
/**
* Returns the data type of the output column.
@@ -86,22 +102,20 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
*/
protected def validateInputType(inputType: DataType): Unit = {}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
- if (schema.fieldNames.contains(map(outputCol))) {
- throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+ if (schema.fieldNames.contains($(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
- StructField(map(outputCol), outputDataType, nullable = false)
+ StructField($(outputCol), outputDataType, nullable = false)
StructType(outputFields)
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- dataset.withColumn(map(outputCol),
- callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ dataset.withColumn($(outputCol),
+ callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 29339c98f5..d3361e2470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -27,7 +26,6 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
-
/**
* :: DeveloperApi ::
* Params for classification.
@@ -40,12 +38,10 @@ private[spark] trait ClassifierParams extends PredictorParams
override protected def validateAndTransformSchema(
schema: StructType,
- paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = extractParamMap(paramMap)
- SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+ val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}
}
@@ -102,27 +98,16 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]].
*
* @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
-
- // Prepare model
- val tmpModel = if (paramMap.size != 0) {
- val tmpModel = this.copy()
- Params.inheritValues(paramMap, parent, tmpModel)
- tmpModel
- } else {
- this
- }
+ transformSchema(dataset.schema, logging = true)
val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
@@ -158,7 +143,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
*/
@DeveloperApi
protected def predictRaw(features: FeaturesType): Vector
-
}
private[ml] object ClassificationModel {
@@ -167,38 +151,35 @@ private[ml] object ClassificationModel {
* Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
* since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
* @param dataset Input dataset
- * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge
- * should already be done.
* @return (number of columns added, transformed dataset)
*/
def transformColumnsImpl[FeaturesType](
dataset: DataFrame,
- model: ClassificationModel[FeaturesType, _],
- map: ParamMap): (Int, DataFrame) = {
+ model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = {
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var tmpData = dataset
var numColsOutput = 0
- if (map(model.rawPredictionCol) != "") {
+ if (model.getRawPredictionCol != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.withColumn(map(model.rawPredictionCol),
- callUDF(features2raw, new VectorUDT, col(map(model.featuresCol))))
+ tmpData = tmpData.withColumn(model.getRawPredictionCol,
+ callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol)))
numColsOutput += 1
- if (map(model.predictionCol) != "") {
+ if (model.getPredictionCol != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
- tmpData = tmpData.withColumn(map(model.predictionCol),
- callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol))))
+ tmpData = tmpData.withColumn(model.getPredictionCol,
+ callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol)))
numColsOutput += 1
}
- } else if (map(model.predictionCol) != "") {
+ } else if (model.getPredictionCol != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.withColumn(map(model.predictionCol),
- callUDF(features2pred, DoubleType, col(map(model.featuresCol))))
+ tmpData = tmpData.withColumn(model.getPredictionCol,
+ callUDF(features2pred, DoubleType, col(model.getFeaturesCol)))
numColsOutput += 1
}
(numColsOutput, tmpData)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ee2a8dc6db..419e5ba05d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,9 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -64,22 +63,20 @@ final class DecisionTreeClassifier
override def setImpurity(value: String): this.type = super.setImpurity(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): DecisionTreeClassificationModel = {
+ override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val oldModel = OldDecisionTree.train(oldDataset, strategy)
- DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -106,7 +103,6 @@ object DecisionTreeClassifier {
@AlphaComponent
final class DecisionTreeClassificationModel private[ml] (
override val parent: DecisionTreeClassifier,
- override val fittingParamMap: ParamMap,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -118,10 +114,8 @@ final class DecisionTreeClassificationModel private[ml] (
rootNode.predict(features)
}
- override protected def copy(): DecisionTreeClassificationModel = {
- val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
+ copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
}
override def toString: String = {
@@ -140,12 +134,11 @@ private[ml] object DecisionTreeClassificationModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
- new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ new DecisionTreeClassificationModel(parent, rootNode)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3d849867d4..534ea95b1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Param, Params, ParamMap}
+import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
@@ -31,12 +31,11 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss}
+import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -112,7 +111,7 @@ final class GBTClassifier
def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
- def getLossType: String = getOrDefault(lossType).toLowerCase
+ def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
@@ -124,25 +123,23 @@ final class GBTClassifier
}
}
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): GBTClassificationModel = {
+ override protected def train(dataset: DataFrame): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("GBTClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
require(numClasses == 2,
s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
- GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -165,7 +162,6 @@ object GBTClassifier {
@AlphaComponent
final class GBTClassificationModel(
override val parent: GBTClassifier,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTClassificationModel]
@@ -188,10 +184,8 @@ final class GBTClassificationModel(
if (prediction > 0.0) 1.0 else 0.0
}
- override protected def copy(): GBTClassificationModel = {
- val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): GBTClassificationModel = {
+ copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -210,14 +204,13 @@ private[ml] object GBTClassificationModel {
def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
- DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index cc8b0721cf..b73be035e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,12 +21,11 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
-
/**
* Params for logistic regression.
*/
@@ -59,9 +58,9 @@ class LogisticRegression
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
- override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -69,17 +68,17 @@ class LogisticRegression
// Train model
val lr = new LogisticRegressionWithLBFGS()
- .setIntercept(paramMap(fitIntercept))
+ .setIntercept($(fitIntercept))
lr.optimizer
- .setRegParam(paramMap(regParam))
- .setNumIterations(paramMap(maxIter))
+ .setRegParam($(regParam))
+ .setNumIterations($(maxIter))
val oldModel = lr.run(oldDataset)
- val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
+ val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)
if (handlePersistence) {
oldDataset.unpersist()
}
- lrm
+ copyValues(lrm)
}
}
@@ -92,7 +91,6 @@ class LogisticRegression
@AlphaComponent
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
@@ -110,16 +108,14 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-m))
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This is overridden (a) to be more efficient (avoiding re-computing values when creating
// multiple output columns) and (b) to handle threshold, which the abstractions do not use.
// TODO: We should abstract away the steps defined by UDFs below so that the abstractions
// can call whichever UDFs are needed to create the output columns.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
-
- val map = extractParamMap(paramMap)
+ transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -128,41 +124,41 @@ class LogisticRegressionModel private[ml] (
// prediction (max margin)
var tmpData = dataset
var numColsOutput = 0
- if (map(rawPredictionCol) != "") {
+ if ($(rawPredictionCol) != "") {
val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.withColumn(map(rawPredictionCol),
- callUDF(features2raw, new VectorUDT, col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col($(featuresCol))))
numColsOutput += 1
}
- if (map(probabilityCol) != "") {
- if (map(rawPredictionCol) != "") {
+ if ($(probabilityCol) != "") {
+ if ($(rawPredictionCol) != "") {
val raw2prob = udf { (rawPreds: Vector) =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
Vectors.dense(1.0 - prob1, prob1): Vector
}
- tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol))))
+ tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol))))
} else {
val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
- tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol))))
}
numColsOutput += 1
}
- if (map(predictionCol) != "") {
- val t = map(threshold)
- if (map(probabilityCol) != "") {
+ if ($(predictionCol) != "") {
+ val t = $(threshold)
+ if ($(probabilityCol) != "") {
val predict = udf { probs: Vector =>
if (probs(1) > t) 1.0 else 0.0
}
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol))))
- } else if (map(rawPredictionCol) != "") {
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol))))
+ } else if ($(rawPredictionCol) != "") {
val predict = udf { rawPreds: Vector =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol))))
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol))))
} else {
val predict = udf { features: Vector => this.predict(features) }
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol))))
}
numColsOutput += 1
}
@@ -193,9 +189,7 @@ class LogisticRegressionModel private[ml] (
Vectors.dense(0.0, m)
}
- override protected def copy(): LogisticRegressionModel = {
- val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): LogisticRegressionModel = {
+ copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 10404548cc..8519841c5c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -34,12 +33,10 @@ private[classification] trait ProbabilisticClassifierParams
override protected def validateAndTransformSchema(
schema: StructType,
- paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = extractParamMap(paramMap)
- SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT)
+ val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ SchemaUtils.appendColumn(parentSchema, $(probabilityCol), new VectorUDT)
}
}
@@ -95,36 +92,22 @@ private[spark] abstract class ProbabilisticClassificationModel[
* - probability of each class as [[probabilityCol]] of type [[Vector]].
*
* @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
-
- // Prepare model
- val tmpModel = if (paramMap.size != 0) {
- val tmpModel = this.copy()
- Params.inheritValues(paramMap, parent, tmpModel)
- tmpModel
- } else {
- this
- }
+ transformSchema(dataset.schema, logging = true)
val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
// Output selected columns only.
- if (map(probabilityCol) != "") {
+ if ($(probabilityCol) != "") {
// output probabilities
- val features2probs: FeaturesType => Vector = (features) => {
- tmpModel.predictProbabilities(features)
- }
- outputData.withColumn(map(probabilityCol),
- callUDF(features2probs, new VectorUDT, col(map(featuresCol))))
+ outputData.withColumn($(probabilityCol),
+ callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))))
} else {
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index cfd6508fce..17f59bb42e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -22,18 +22,17 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -81,24 +80,22 @@ final class RandomForestClassifier
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): RandomForestClassificationModel = {
+ override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val oldModel = OldRandomForest.trainClassifier(
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
- RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -123,7 +120,6 @@ object RandomForestClassifier {
@AlphaComponent
final class RandomForestClassificationModel private[ml] (
override val parent: RandomForestClassifier,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeClassificationModel])
extends PredictionModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -150,10 +146,8 @@ final class RandomForestClassificationModel private[ml] (
votes.maxBy(_._2)._1
}
- override protected def copy(): RandomForestClassificationModel = {
- val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): RandomForestClassificationModel = {
+ copyValues(new RandomForestClassificationModel(parent, _trees), extra)
}
override def toString: String = {
@@ -172,14 +166,13 @@ private[ml] object RandomForestClassificationModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
- DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestClassificationModel(parent, fittingParamMap, newTrees)
+ new RandomForestClassificationModel(parent, newTrees)
}
}
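
A small side effect of the `$` alias that is visible in this hunk: because `$` is an ordinary method, it must be wrapped as `${$(labelCol)}` inside interpolated strings, since a bare `$(...)` does not compile in an s-interpolator. A minimal stand-alone sketch; the `DollarAliasSketch` object and its toy value map are assumptions, not Spark code:

~~~scala
object DollarAliasSketch {
  private val values = Map("labelCol" -> "label", "featuresCol" -> "features")

  // Toy alias standing in for Params.$ / getOrDefault.
  def $(name: String): String = values(name)

  def main(args: Array[String]): Unit = {
    val labelCol = "labelCol"
    // ${ $(labelCol) } evaluates the method call inside the interpolation block.
    println(s"invalid label column ${$(labelCol)}") // invalid label column label
  }
}
~~~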
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index c865eb9fe0..e5a73c6087 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -33,8 +33,7 @@ import org.apache.spark.sql.types.DoubleType
* Evaluator for binary classification, which expects two input columns: score and label.
*/
@AlphaComponent
-class BinaryClassificationEvaluator extends Evaluator with Params
- with HasRawPredictionCol with HasLabelCol {
+class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol {
/**
* param for metric name in evaluation
@@ -44,7 +43,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
"metric name in evaluation (areaUnderROC|areaUnderPR)")
/** @group getParam */
- def getMetricName: String = getOrDefault(metricName)
+ def getMetricName: String = $(metricName)
/** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
@@ -57,20 +56,18 @@ class BinaryClassificationEvaluator extends Evaluator with Params
setDefault(metricName -> "areaUnderROC")
- override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
- val map = extractParamMap(paramMap)
-
+ override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
- SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT)
- SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
- val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
+ val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
.map { case Row(rawPrediction: Vector, label: Double) =>
(rawPrediction(1), label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
- val metric = map(metricName) match {
+ val metric = $(metricName) match {
case "areaUnderROC" =>
metrics.areaUnderROC()
case "areaUnderPR" =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index f3ce6dfca2..6eb1db6971 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -44,7 +44,7 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
/** @group getParam */
- def getThreshold: Double = getOrDefault(threshold)
+ def getThreshold: Double = $(threshold)
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
@@ -57,23 +57,21 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val td = map(threshold)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val td = $(threshold)
val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
dataset.select(col("*"),
- binarizer(col(map(inputCol))).as(outputColName, metadata))
+ binarizer(col($(inputCol))).as(outputColName, metadata))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 0b3128f9ee..c305a819a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
@@ -42,13 +42,13 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
setDefault(numFeatures -> (1 << 18))
/** @group getParam */
- def getNumFeatures: Int = getOrDefault(numFeatures)
+ def getNumFeatures: Int = $(numFeatures)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
- val hashingTF = new feature.HashingTF(paramMap(numFeatures))
+ override protected def createTransformFunc: Iterable[_] => Vector = {
+ val hashingTF = new feature.HashingTF($(numFeatures))
hashingTF.transform
}
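
`createTransformFunc` is now a zero-argument member that closes over the transformer's embedded param values instead of receiving a `ParamMap`. A self-contained sketch of that contract with toy types; `ToyHashingTF` is invented and its hashing scheme is only illustrative:

~~~scala
object CreateTransformFuncSketch {
  class ToyHashingTF(numFeatures: Int = 1 << 18) {
    // Built from the embedded value; callers simply apply the returned function.
    def createTransformFunc: Iterable[_] => Map[Int, Int] = { terms =>
      terms.toSeq
        .map(t => math.abs(t.hashCode) % numFeatures)  // bucket index per term
        .groupBy(identity)
        .map { case (idx, hits) => idx -> hits.size }  // term frequency per bucket
    }
  }

  def main(args: Array[String]): Unit = {
    val tf = new ToyHashingTF(numFeatures = 16)
    println(tf.createTransformFunc(Seq("a", "b", "a")))
  }
}
~~~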
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index e6a62d998b..d901a20aed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -43,7 +43,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
setDefault(minDocFreq -> 0)
/** @group getParam */
- def getMinDocFreq: Int = getOrDefault(minDocFreq)
+ def getMinDocFreq: Int = $(minDocFreq)
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
@@ -51,10 +51,9 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/**
* Validate and transform the input schema.
*/
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
- SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
@@ -71,18 +70,15 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val idf = new feature.IDF(map(minDocFreq)).fit(input)
- val model = new IDFModel(this, map, idf)
- Params.inheritValues(map, this, model)
- model
+ override def fit(dataset: DataFrame): IDFModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+ val idf = new feature.IDF($(minDocFreq)).fit(input)
+ copyValues(new IDFModel(this, idf))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -93,7 +89,6 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
@AlphaComponent
class IDFModel private[ml] (
override val parent: IDF,
- override val fittingParamMap: ParamMap,
idfModel: feature.IDFModel)
extends Model[IDFModel] with IDFBase {
@@ -103,14 +98,13 @@ class IDFModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val idf = udf { vec: Vector => idfModel.transform(vec) }
- dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
+ dataset.withColumn($(outputCol), idf(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
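
The IDF hunks show the new `fit(dataset)` flow: read embedded params with `$`, train, then hand the result to `copyValues` instead of calling `Params.inheritValues` and keeping a `fittingParamMap`. A toy sketch of that flow; names like `ToyIDF` and the simplified document-frequency formula are assumptions for illustration:

~~~scala
object FitCopyValuesSketch {
  type ParamMap = Map[String, Any]

  class ToyIDF(val params: ParamMap = Map("minDocFreq" -> 0, "outputCol" -> "idf")) {
    def fit(docs: Seq[Seq[String]]): ToyIDFModel = {
      val minDocFreq = params("minDocFreq").asInstanceOf[Int]
      val numDocs = docs.size.toDouble
      val idf = docs.flatMap(_.distinct)                 // one occurrence per document
        .groupBy(identity)
        .collect { case (term, hits) if hits.size >= minDocFreq =>
          term -> math.log((numDocs + 1) / (hits.size + 1))
        }
      copyValues(new ToyIDFModel(idf))                   // params flow estimator -> model
    }
    private def copyValues(model: ToyIDFModel): ToyIDFModel = { model.params = params; model }
  }

  class ToyIDFModel(val idf: Map[String, Double]) { var params: ParamMap = Map.empty }

  def main(args: Array[String]): Unit = {
    val model = new ToyIDF().fit(Seq(Seq("a", "b"), Seq("a")))
    println(model.params("outputCol")) // idf -- copied from the estimator at fit time
    println(model.idf)
  }
}
~~~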
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index bd2b5f6067..755b46a64c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
@@ -41,13 +41,13 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
setDefault(p -> 2.0)
/** @group getParam */
- def getP: Double = getOrDefault(p)
+ def getP: Double = $(p)
/** @group setParam */
def setP(value: Double): this.type = set(p, value)
- override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
- val normalizer = new feature.Normalizer(paramMap(p))
+ override protected def createTransformFunc: Vector => Vector = {
+ val normalizer = new feature.Normalizer($(p))
normalizer.transform
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 1b7c939c2d..63e190c8aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -47,14 +47,13 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp
setDefault(degree -> 2)
/** @group getParam */
- def getDegree: Int = getOrDefault(degree)
+ def getDegree: Int = $(degree)
/** @group setParam */
def setDegree(value: Int): this.type = set(degree, value)
- override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v =>
- val d = paramMap(degree)
- PolynomialExpansion.expand(v, d)
+ override protected def createTransformFunc: Vector => Vector = { v =>
+ PolynomialExpansion.expand(v, $(degree))
}
override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index a0e9ed32e0..7cad59ff3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -71,25 +71,21 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd))
+ override def fit(dataset: DataFrame): StandardScalerModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+ val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
- val model = new StandardScalerModel(this, map, scalerModel)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new StandardScalerModel(this, scalerModel))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${map(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains(map(outputCol)),
- s"Output column ${map(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ s"Input column ${$(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}
@@ -101,7 +97,6 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
@AlphaComponent
class StandardScalerModel private[ml] (
override val parent: StandardScaler,
- override val fittingParamMap: ParamMap,
scaler: feature.StandardScalerModel)
extends Model[StandardScalerModel] with StandardScalerParams {
@@ -111,21 +106,19 @@ class StandardScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
- dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val scale = udf { scaler.transform _ }
+ dataset.withColumn($(outputCol), scale(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${map(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains(map(outputCol)),
- s"Output column ${map(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ s"Input column ${$(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9db3b29e10..3d78537ad8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -34,18 +34,17 @@ import org.apache.spark.util.collection.OpenHashMap
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
/** Validates and transforms the input schema. */
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputColName = map(inputCol)
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
- val attr = NominalAttribute.defaultAttr.withName(map(outputCol))
+ val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}
@@ -69,19 +68,16 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
// TODO: handle unseen labels
- override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
- val map = extractParamMap(paramMap)
- val counts = dataset.select(col(map(inputCol)).cast(StringType))
+ override def fit(dataset: DataFrame): StringIndexerModel = {
+ val counts = dataset.select(col($(inputCol)).cast(StringType))
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
- val model = new StringIndexerModel(this, map, labels)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new StringIndexerModel(this, labels))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -92,7 +88,6 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
@AlphaComponent
class StringIndexerModel private[ml] (
override val parent: StringIndexer,
- override val fittingParamMap: ParamMap,
labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
private val labelToIndex: OpenHashMap[String, Double] = {
@@ -112,8 +107,7 @@ class StringIndexerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
@@ -122,14 +116,14 @@ class StringIndexerModel private[ml] (
throw new SparkException(s"Unseen label: $label.")
}
}
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"),
- indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
+ indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
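
The StringIndexer hunks order labels by descending frequency and fail on unseen labels at transform time. A plain-collections sketch of that behavior; an `IllegalArgumentException` stands in here for the `SparkException` thrown in the diff:

~~~scala
object StringIndexSketch {
  def fit(column: Seq[String]): Map[String, Double] = {
    val labels = column.groupBy(identity).toSeq
      .sortBy { case (_, hits) => -hits.size }      // most frequent label gets index 0
      .map(_._1)
    labels.zipWithIndex.map { case (l, i) => l -> i.toDouble }.toMap
  }

  def index(labelToIndex: Map[String, Double])(label: String): Double =
    labelToIndex.getOrElse(label, throw new IllegalArgumentException(s"Unseen label: $label."))

  def main(args: Array[String]): Unit = {
    val m = fit(Seq("a", "b", "a", "c", "a", "b"))
    println(m)             // a -> 0.0, b -> 1.0, c -> 2.0
    println(index(m)("b")) // 1.0
    // index(m)("z") would throw: Unseen label: z.
  }
}
~~~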
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 01752ba482..2863b76215 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
-import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
* :: AlphaComponent ::
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
@AlphaComponent
class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
- override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ override protected def createTransformFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
}
@@ -62,7 +62,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
/** @group getParam */
- def getMinTokenLength: Int = getOrDefault(minTokenLength)
+ def getMinTokenLength: Int = $(minTokenLength)
/**
* Indicates whether regex splits on gaps (true) or matching tokens (false).
@@ -75,7 +75,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
def setGaps(value: Boolean): this.type = set(gaps, value)
/** @group getParam */
- def getGaps: Boolean = getOrDefault(gaps)
+ def getGaps: Boolean = $(gaps)
/**
* Regex pattern used by tokenizer.
@@ -88,14 +88,14 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
def setPattern(value: String): this.type = set(pattern, value)
/** @group getParam */
- def getPattern: String = getOrDefault(pattern)
+ def getPattern: String = $(pattern)
setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+")
- override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
- val re = paramMap(pattern).r
- val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
- val minLength = paramMap(minTokenLength)
+ override protected def createTransformFunc: String => Seq[String] = { str =>
+ val re = $(pattern).r
+ val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+ val minLength = $(minTokenLength)
tokens.filter(_.length >= minLength)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 5e781a326d..8f2e62a8e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
@@ -42,13 +41,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
}
val schema = dataset.schema
- val inputColNames = map(inputCols)
+ val inputColNames = $(inputCols)
val args = inputColNames.map { c =>
schema(c).dataType match {
case DoubleType => dataset(c)
@@ -56,13 +54,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
}
}
- dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol)))
+ dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol)))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputColNames = map(inputCols)
- val outputColName = map(outputCol)
+ override def transformSchema(schema: StructType): StructType = {
+ val inputColNames = $(inputCols)
+ val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
inputDataTypes.foreach {
case _: NumericType | BooleanType =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index ed833c63c7..07ea579d69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -18,19 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
- Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
-import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.callUDF
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.collection.OpenHashSet
-
/** Private trait for params for VectorIndexer and VectorIndexerModel */
private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol {
@@ -49,7 +47,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
setDefault(maxCategories -> 20)
/** @group getParam */
- def getMaxCategories: Int = getOrDefault(maxCategories)
+ def getMaxCategories: Int = $(maxCategories)
}
/**
@@ -100,33 +98,29 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val firstRow = dataset.select(map(inputCol)).take(1)
+ override def fit(dataset: DataFrame): VectorIndexerModel = {
+ transformSchema(dataset.schema, logging = true)
+ val firstRow = dataset.select($(inputCol)).take(1)
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
val numFeatures = firstRow(0).getAs[Vector](0).size
- val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val maxCats = map(maxCategories)
+ val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+ val maxCats = $(maxCategories)
val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter =>
val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats)
iter.foreach(localCatStats.addVector)
Iterator(localCatStats)
}.reduce((stats1, stats2) => stats1.merge(stats2))
- val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType): StructType = {
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
- val map = extractParamMap(paramMap)
val dataType = new VectorUDT
- require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol")
- require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol")
- SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
- SchemaUtils.appendColumn(schema, map(outputCol), dataType)
+ require(isDefined(inputCol), s"VectorIndexer requires input column parameter: $inputCol")
+ require(isDefined(outputCol), s"VectorIndexer requires output column parameter: $outputCol")
+ SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
+ SchemaUtils.appendColumn(schema, $(outputCol), dataType)
}
}
@@ -243,7 +237,6 @@ private object VectorIndexer {
@AlphaComponent
class VectorIndexerModel private[ml] (
override val parent: VectorIndexer,
- override val fittingParamMap: ParamMap,
val numFeatures: Int,
val categoryMaps: Map[Int, Map[Double, Int]])
extends Model[VectorIndexerModel] with VectorIndexerParams {
@@ -326,35 +319,33 @@ class VectorIndexerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val newField = prepOutputField(dataset.schema, map)
- val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
- dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val newField = prepOutputField(dataset.schema)
+ val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol)))
+ dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
+ override def transformSchema(schema: StructType): StructType = {
val dataType = new VectorUDT
- require(map.contains(inputCol),
+ require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
- require(map.contains(outputCol),
+ require(isDefined(outputCol),
s"VectorIndexerModel requires output column parameter: $outputCol")
- SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
+ SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
// If the input metadata specifies numFeatures, compare with expected numFeatures.
- val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
+ val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
} else {
origAttrGroup.numAttributes
}
require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" +
- s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" +
+ s" $numFeatures features, but input column ${$(inputCol)} had metadata specifying" +
s" ${origAttrGroup.numAttributes.get} features.")
- val newField = prepOutputField(schema, map)
+ val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
StructType(outputFields)
}
@@ -362,11 +353,10 @@ class VectorIndexerModel private[ml] (
/**
* Prepare the output column field, including per-feature metadata.
* @param schema Input schema
- * @param map Parameter map (with this class' embedded parameter map folded in)
* @return Output column field. This field does not contain non-ML metadata.
*/
- private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
- val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
+ private def prepOutputField(schema: StructType): StructField = {
+ val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
// Convert original attributes to modified attributes
val origAttrs: Array[Attribute] = origAttrGroup.attributes.get
@@ -389,7 +379,7 @@ class VectorIndexerModel private[ml] (
} else {
partialFeatureAttributes
}
- val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
+ val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
newAttributeGroup.toStructField()
}
}
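
VectorIndexer's `fit` builds per-partition category statistics and merges them with `reduce`, as the hunk above shows. The sketch below mimics that aggregation shape with plain Scala collections; the `Stats` class and its cap-at-`maxCategories` bookkeeping are simplified assumptions, not the real `CategoryStats`:

~~~scala
object CategoryStatsSketch {
  // Distinct values seen per feature index, capped at maxCategories per feature.
  final class Stats(numFeatures: Int, maxCategories: Int) {
    val seen: Array[Set[Double]] = Array.fill(numFeatures)(Set.empty[Double])

    def add(v: Array[Double]): this.type = {
      v.indices.foreach { i => if (seen(i).size < maxCategories) seen(i) += v(i) }
      this
    }

    def merge(other: Stats): this.type = {
      seen.indices.foreach { i => seen(i) = (seen(i) ++ other.seen(i)).take(maxCategories) }
      this
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(Array(0.0, 1.0), Array(1.0, 1.0), Array(0.0, 2.0), Array(1.0, 3.0))
    val stats = data.grouped(2).toSeq                    // stand-in for RDD partitions
      .map(part => part.foldLeft(new Stats(2, 4))(_ add _))
      .reduce(_ merge _)
    println(stats.seen.map(_.size).mkString(","))        // 2,3 distinct values per feature
  }
}
~~~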
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 0163fa8bd8..34ff929701 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,16 +18,16 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row}
/**
* Params for [[Word2Vec]] and [[Word2VecModel]].
@@ -43,7 +43,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(vectorSize -> 100)
/** @group getParam */
- def getVectorSize: Int = getOrDefault(vectorSize)
+ def getVectorSize: Int = $(vectorSize)
/**
* Number of partitions for sentences of words.
@@ -53,7 +53,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(numPartitions -> 1)
/** @group getParam */
- def getNumPartitions: Int = getOrDefault(numPartitions)
+ def getNumPartitions: Int = $(numPartitions)
/**
* The minimum number of times a token must appear to be included in the word2vec model's
@@ -64,7 +64,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(minCount -> 5)
/** @group getParam */
- def getMinCount: Int = getOrDefault(minCount)
+ def getMinCount: Int = $(minCount)
setDefault(stepSize -> 0.025)
setDefault(maxIter -> 1)
@@ -73,10 +73,9 @@ private[feature] trait Word2VecBase extends Params
/**
* Validate and transform the input schema.
*/
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true))
- SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
@@ -112,25 +111,22 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v }
+ override def fit(dataset: DataFrame): Word2VecModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
val wordVectors = new feature.Word2Vec()
- .setLearningRate(map(stepSize))
- .setMinCount(map(minCount))
- .setNumIterations(map(maxIter))
- .setNumPartitions(map(numPartitions))
- .setSeed(map(seed))
- .setVectorSize(map(vectorSize))
+ .setLearningRate($(stepSize))
+ .setMinCount($(minCount))
+ .setNumIterations($(maxIter))
+ .setNumPartitions($(numPartitions))
+ .setSeed($(seed))
+ .setVectorSize($(vectorSize))
.fit(input)
- val model = new Word2VecModel(this, map, wordVectors)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new Word2VecModel(this, wordVectors))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -141,7 +137,6 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
@AlphaComponent
class Word2VecModel private[ml] (
override val parent: Word2Vec,
- override val fittingParamMap: ParamMap,
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
@@ -155,15 +150,14 @@ class Word2VecModel private[ml] (
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
- Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double])
+ Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
} else {
- val cum = Vectors.zeros(map(vectorSize))
+ val cum = Vectors.zeros($(vectorSize))
val model = bWordVectors.value.getVectors
for (word <- sentence) {
if (model.contains(word)) {
@@ -176,10 +170,10 @@ class Word2VecModel private[ml] (
cum
}
}
- dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
+ dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
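
Word2VecModel's transform averages the vectors of a sentence and returns a zero vector of size `$(vectorSize)` for empty input. A toy version with plain arrays; the exact BLAS calls and scaling in the real model fall outside the shown hunk, so the simple average below is an assumption:

~~~scala
object SentenceVectorSketch {
  def sentenceVector(wordVectors: Map[String, Array[Double]], vectorSize: Int)
                    (sentence: Seq[String]): Array[Double] = {
    if (sentence.isEmpty) {
      Array.fill(vectorSize)(0.0)                 // zero vector for an empty sentence
    } else {
      val cum = Array.fill(vectorSize)(0.0)
      for (word <- sentence; vec <- wordVectors.get(word); i <- vec.indices) {
        cum(i) += vec(i)                          // accumulate known word vectors
      }
      cum.map(_ / sentence.size)                  // simple average (illustrative scaling)
    }
  }

  def main(args: Array[String]): Unit = {
    val vectors = Map("spark" -> Array(1.0, 0.0), "ml" -> Array(0.0, 1.0))
    println(sentenceVector(vectors, 2)(Seq("spark", "ml", "unknown")).mkString(", "))
    // 0.333..., 0.333... -- unknown words add nothing but still count in the length
  }
}
~~~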
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 195333a5cc..e8b3628140 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -18,18 +18,17 @@
package org.apache.spark.ml.impl.estimator
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
-
/**
* :: DeveloperApi ::
*
@@ -44,7 +43,6 @@ private[spark] trait PredictorParams extends Params
/**
* Validates and transforms the input schema with the provided param map.
* @param schema input schema
- * @param paramMap additional parameters
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -52,17 +50,15 @@ private[spark] trait PredictorParams extends Params
*/
protected def validateAndTransformSchema(
schema: StructType,
- paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val map = extractParamMap(paramMap)
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
- SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType)
+ SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
// TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
}
- SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType)
+ SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
}
@@ -96,14 +92,15 @@ private[spark] abstract class Predictor[
/** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
- override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ override def fit(dataset: DataFrame): M = {
// This handles a few items such as schema validation.
// Developers only need to implement train().
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val model = train(dataset, map)
- Params.inheritValues(map, this, model) // copy params to model
- model
+ transformSchema(dataset.schema, logging = true)
+ copyValues(train(dataset))
+ }
+
+ override def copy(extra: ParamMap): Learner = {
+ super.copy(extra).asInstanceOf[Learner]
}
/**
@@ -114,12 +111,10 @@ private[spark] abstract class Predictor[
* and copying parameters into the model.
*
* @param dataset Training dataset
- * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already
- * been combined with the embedded ParamMap.
* @return Fitted model
*/
@DeveloperApi
- protected def train(dataset: DataFrame, paramMap: ParamMap): M
+ protected def train(dataset: DataFrame): M
/**
* :: DeveloperApi ::
@@ -134,17 +129,16 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true, featuresDataType)
}
/**
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
- protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
- val map = extractParamMap(paramMap)
- dataset.select(map(labelCol), map(featuresCol))
+ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
+ dataset.select($(labelCol), $(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
@@ -186,8 +180,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false, featuresDataType)
}
/**
@@ -195,30 +189,16 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
* the predictions as a new column [[predictionCol]].
*
* @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset with [[predictionCol]] of type [[Double]]
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
-
- // Prepare model
- val tmpModel = if (paramMap.size != 0) {
- val tmpModel = this.copy()
- Params.inheritValues(paramMap, parent, tmpModel)
- tmpModel
- } else {
- this
- }
+ transformSchema(dataset.schema, logging = true)
- if (map(predictionCol) != "") {
- val pred: FeaturesType => Double = (features) => {
- tmpModel.predict(features)
- }
- dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol))))
+ if ($(predictionCol) != "") {
+ dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
@@ -234,10 +214,4 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
*/
@DeveloperApi
protected def predict(features: FeaturesType): Double
-
- /**
- * Create a copy of the model.
- * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
- */
- protected def copy(): M
}
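
With `paramMap` removed from the Predictor API, subclasses implement only `train(dataset)`, while the final `fit` validates the input and copies params onto the trained model. A toy rendering of that template method; all types here, such as `ToyPredictor`, are invented for illustration and are not the real spark.ml traits:

~~~scala
object PredictorTemplateSketch {
  type ParamMap = Map[String, Any]
  type Dataset  = Seq[(Double, Array[Double])]           // (label, features) rows

  abstract class ToyPredictor[M <: ToyModel] {
    var params: ParamMap = Map("predictionCol" -> "prediction")

    final def fit(dataset: Dataset): M = {
      require(dataset.nonEmpty, "input check stands in for transformSchema")
      copyValues(train(dataset))                          // params flow estimator -> model
    }

    protected def train(dataset: Dataset): M              // the only method subclasses write

    private def copyValues(model: M): M = { model.params = params; model }
  }

  abstract class ToyModel {
    var params: ParamMap = Map.empty
    def predict(features: Array[Double]): Double
  }

  class MeanModel(mean: Double) extends ToyModel {
    def predict(features: Array[Double]): Double = mean
  }

  class MeanPredictor extends ToyPredictor[MeanModel] {
    protected def train(dataset: Dataset): MeanModel =
      new MeanModel(dataset.map(_._1).sum / dataset.size)
  }

  def main(args: Array[String]): Unit = {
    val model = new MeanPredictor().fit(Seq((1.0, Array(0.0)), (3.0, Array(1.0))))
    println(model.params("predictionCol")) // prediction -- copied from the estimator
    println(model.predict(Array(0.0)))     // 2.0 -- mean of the training labels
  }
}
~~~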
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index fb770622e7..0e225627d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -20,14 +20,11 @@ package org.apache.spark.ml.impl.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.impl.estimator.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo,
- BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
- Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
-
/**
* :: DeveloperApi ::
* Parameters for Decision Tree-based algorithms.
@@ -123,43 +120,43 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group getParam */
- final def getMaxDepth: Int = getOrDefault(maxDepth)
+ final def getMaxDepth: Int = $(maxDepth)
/** @group setParam */
def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group getParam */
- final def getMaxBins: Int = getOrDefault(maxBins)
+ final def getMaxBins: Int = $(maxBins)
/** @group setParam */
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group getParam */
- final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+ final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
/** @group setParam */
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group getParam */
- final def getMinInfoGain: Double = getOrDefault(minInfoGain)
+ final def getMinInfoGain: Double = $(minInfoGain)
/** @group expertSetParam */
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertGetParam */
- final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+ final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
/** @group expertSetParam */
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/** @group expertGetParam */
- final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+ final def getCacheNodeIds: Boolean = $(cacheNodeIds)
/** @group expertSetParam */
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group expertGetParam */
- final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+ final def getCheckpointInterval: Int = $(checkpointInterval)
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
@@ -206,7 +203,7 @@ private[ml] trait TreeClassifierParams extends Params {
def setImpurity(value: String): this.type = set(impurity, value)
/** @group getParam */
- final def getImpurity: String = getOrDefault(impurity).toLowerCase
+ final def getImpurity: String = $(impurity).toLowerCase
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -248,7 +245,7 @@ private[ml] trait TreeRegressorParams extends Params {
def setImpurity(value: String): this.type = set(impurity, value)
/** @group getParam */
- final def getImpurity: String = getOrDefault(impurity).toLowerCase
+ final def getImpurity: String = $(impurity).toLowerCase
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -291,7 +288,7 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
/** @group getParam */
- final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
+ final def getSubsamplingRate: Double = $(subsamplingRate)
/** @group setParam */
def setSeed(value: Long): this.type = set(seed, value)
@@ -364,13 +361,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getNumTrees: Int = getOrDefault(numTrees)
+ final def getNumTrees: Int = $(numTrees)
/** @group setParam */
def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}
private[ml] object RandomForestParams {
@@ -418,7 +415,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
def setStepSize(value: Double): this.type = set(stepSize, value)
/** @group getParam */
- final def getStepSize: Double = getOrDefault(stepSize)
+ final def getStepSize: Double = $(stepSize)
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index df6360dce6..51ce19d29c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -23,7 +23,7 @@ import java.util.NoSuchElementException
import scala.annotation.varargs
import scala.collection.mutable
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
/**
@@ -49,7 +49,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
* Assert that the given value is valid for this parameter.
*
* Note: Parameter checks involving interactions between multiple parameters should be
- * implemented in [[Params.validate()]]. Checks for input/output columns should be
+ * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
* implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
*
* DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
@@ -258,7 +258,9 @@ trait Params extends Identifiable with Serializable {
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
*/
- def validate(paramMap: ParamMap): Unit = { }
+ def validateParams(paramMap: ParamMap): Unit = {
+ copy(paramMap).validateParams()
+ }
/**
* Validates parameter values stored internally.
@@ -269,7 +271,11 @@ trait Params extends Identifiable with Serializable {
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
*/
- def validate(): Unit = validate(ParamMap.empty)
+ def validateParams(): Unit = {
+ params.filter(isDefined _).foreach { param =>
+ param.asInstanceOf[Param[Any]].validate($(param))
+ }
+ }
/**
* Returns the documentation of all params.
@@ -288,6 +294,11 @@ trait Params extends Identifiable with Serializable {
defaultParamMap.contains(param) || paramMap.contains(param)
}
+ /** Tests whether this instance contains a param with a given name. */
+ def hasParam(paramName: String): Boolean = {
+ params.exists(_.name == paramName)
+ }
+
/** Gets a param by its name. */
def getParam(paramName: String): Param[Any] = {
params.find(_.name == paramName).getOrElse {
@@ -337,6 +348,9 @@ trait Params extends Identifiable with Serializable {
get(param).orElse(getDefault(param)).get
}
+ /** An alias for [[getOrDefault()]]. */
+ protected final def $[T](param: Param[T]): T = getOrDefault(param)
+
/**
* Sets a default value for a param.
* @param param param to set the default value. Make sure that this param is initialized before
@@ -383,18 +397,30 @@ trait Params extends Identifiable with Serializable {
}
/**
+ * Creates a copy of this instance with a randomly generated uid and some extra params.
+ * The default implementation calls the default constructor to create a new instance, then
+ * copies the embedded and extra parameters over and returns the new instance.
+ * Subclasses should override this method if the default approach is not sufficient.
+ */
+ def copy(extra: ParamMap): Params = {
+ val that = this.getClass.newInstance()
+ copyValues(that, extra)
+ that
+ }
+
+ /**
* Extracts the embedded default param values and user-supplied values, and then merges them with
* extra values from input into a flat param map, where the latter value is used if there exist
* conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
*/
- protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
+ final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
defaultParamMap ++ paramMap ++ extraParamMap
}
/**
* [[extractParamMap]] with no extra values.
*/
- protected final def extractParamMap(): ParamMap = {
+ final def extractParamMap(): ParamMap = {
extractParamMap(ParamMap.empty)
}
@@ -408,34 +434,21 @@ trait Params extends Identifiable with Serializable {
private def shouldOwn(param: Param[_]): Unit = {
require(param.parent.eq(this), s"Param $param does not belong to $this.")
}
-}
-/**
- * :: DeveloperApi ::
- *
- * Helper functionality for developers.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] object Params {
-
- /**
- * Copies parameter values from the parent estimator to the child model it produced.
- * @param paramMap the param map that holds parameters of the parent
- * @param parent the parent estimator
- * @param child the child model
- */
- def inheritValues[E <: Params, M <: E](
- paramMap: ParamMap,
- parent: E,
- child: M): Unit = {
- val childParams = child.params.map(_.name).toSet
- parent.params.foreach { param =>
- if (paramMap.contains(param) && childParams.contains(param.name)) {
- child.set(child.getParam(param.name), paramMap(param))
+ /**
+ * Copies param values from this instance to another instance for params shared by them.
+ * @param to the target instance
+ * @param extra extra params to be copied
+ * @return the target instance with param values copied
+ */
+ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
+ val map = extractParamMap(extra)
+ params.foreach { param =>
+ if (map.contains(param) && to.hasParam(param.name)) {
+ to.set(param.name, map(param))
}
}
+ to
}
}
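To make the new `Params` surface easier to follow, here is a minimal, self-contained sketch of the pattern the hunks above introduce: `$` as an alias for `getOrDefault`, `hasParam`, `copyValues` transferring values for params shared by name, and a default `copy(extra)` built on the no-arg constructor. All names (`MyParam`, `MyParams`, `ToyScaler`) are illustrative stand-ins, not the spark.ml classes; defaults and validation are omitted.

~~~scala
import java.util.UUID

// Illustrative sketch only; the real spark.ml Params also tracks defaults and validation.
class MyParam[T](val name: String)

trait MyParams {
  val uid: String = UUID.randomUUID().toString
  def params: Array[MyParam[_]]

  private var values = Map.empty[String, Any]

  def set[T](p: MyParam[T], value: T): this.type = { values += (p.name -> value); this }
  def getOrDefault[T](p: MyParam[T]): T = values(p.name).asInstanceOf[T]
  // `$` is just a shorthand for getOrDefault, for use inside fit/transform.
  protected final def $[T](p: MyParam[T]): T = getOrDefault(p)

  def hasParam(name: String): Boolean = params.exists(_.name == name)

  // Copy values for params shared by name; extra values win over embedded ones.
  def copyValues[T <: MyParams](to: T, extra: Map[String, Any] = Map.empty): T = {
    val merged = values ++ extra
    params.foreach { p =>
      if (merged.contains(p.name) && to.hasParam(p.name)) {
        to.set(new MyParam[Any](p.name), merged(p.name))
      }
    }
    to
  }

  // Default copy: fresh instance (new uid), embedded and extra values carried over.
  def copy(extra: Map[String, Any]): MyParams = {
    val that = this.getClass.newInstance()
    copyValues(that, extra)
  }
}

class ToyScaler extends MyParams {
  val inputCol = new MyParam[String]("inputCol")
  val maxIter = new MyParam[Int]("maxIter")
  def params: Array[MyParam[_]] = Array(inputCol, maxIter)
  def describe(): String = s"inputCol=${$(inputCol)}, maxIter=${$(maxIter)}"
}

object ParamsSketchDemo extends App {
  val scaler = new ToyScaler
  scaler.set(scaler.inputCol, "features").set(scaler.maxIter, 10)
  val copied = scaler.copy(Map("maxIter" -> 20)).asInstanceOf[ToyScaler]
  assert(copied.uid != scaler.uid)
  println(copied.describe())  // inputCol=features, maxIter=20
}
~~~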
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 7da4bb4b4b..d379172e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -21,8 +21,6 @@ import java.io.PrintWriter
import scala.reflect.ClassTag
-import org.apache.spark.ml.param.ParamValidators
-
/**
* Code generator for shared params (sharedParams.scala). Run under the Spark folder with
* {{{
@@ -142,7 +140,7 @@ private[shared] object SharedParamsCodeGen {
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
| /** @group getParam */
- | final def get$Name: $T = getOrDefault($name)
+ | final def get$Name: $T = $$($name)
|}
|""".stripMargin
}
@@ -169,7 +167,6 @@ private[shared] object SharedParamsCodeGen {
|
|package org.apache.spark.ml.param.shared
|
- |import org.apache.spark.annotation.DeveloperApi
|import org.apache.spark.ml.param._
|import org.apache.spark.util.Utils
|
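Note that `$$` in the template above is the string-interpolation escape for a literal `$`, so the emitted getter body reads `$(name)` rather than `getOrDefault(name)`. A tiny illustrative snippet of the same escaping (values assumed):

~~~scala
object GeneratedGetterDemo extends App {
  val name = "regParam"
  val T = "Double"
  val Name = "RegParam"
  // `$Name` and `$T` interpolate; `$$` escapes to a single literal `$`.
  val getter = s"""  final def get$Name: $T = $$($name)"""
  println(getter)  // prints:   final def getRegParam: Double = $(regParam)
}
~~~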
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e1549f46a6..fb1874ccfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.param.shared
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
import org.apache.spark.util.Utils
@@ -37,7 +36,7 @@ private[ml] trait HasRegParam extends Params {
final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
- final def getRegParam: Double = getOrDefault(regParam)
+ final def getRegParam: Double = $(regParam)
}
/**
@@ -52,7 +51,7 @@ private[ml] trait HasMaxIter extends Params {
final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
- final def getMaxIter: Int = getOrDefault(maxIter)
+ final def getMaxIter: Int = $(maxIter)
}
/**
@@ -69,7 +68,7 @@ private[ml] trait HasFeaturesCol extends Params {
setDefault(featuresCol, "features")
/** @group getParam */
- final def getFeaturesCol: String = getOrDefault(featuresCol)
+ final def getFeaturesCol: String = $(featuresCol)
}
/**
@@ -86,7 +85,7 @@ private[ml] trait HasLabelCol extends Params {
setDefault(labelCol, "label")
/** @group getParam */
- final def getLabelCol: String = getOrDefault(labelCol)
+ final def getLabelCol: String = $(labelCol)
}
/**
@@ -103,7 +102,7 @@ private[ml] trait HasPredictionCol extends Params {
setDefault(predictionCol, "prediction")
/** @group getParam */
- final def getPredictionCol: String = getOrDefault(predictionCol)
+ final def getPredictionCol: String = $(predictionCol)
}
/**
@@ -120,7 +119,7 @@ private[ml] trait HasRawPredictionCol extends Params {
setDefault(rawPredictionCol, "rawPrediction")
/** @group getParam */
- final def getRawPredictionCol: String = getOrDefault(rawPredictionCol)
+ final def getRawPredictionCol: String = $(rawPredictionCol)
}
/**
@@ -137,7 +136,7 @@ private[ml] trait HasProbabilityCol extends Params {
setDefault(probabilityCol, "probability")
/** @group getParam */
- final def getProbabilityCol: String = getOrDefault(probabilityCol)
+ final def getProbabilityCol: String = $(probabilityCol)
}
/**
@@ -152,7 +151,7 @@ private[ml] trait HasThreshold extends Params {
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
/** @group getParam */
- final def getThreshold: Double = getOrDefault(threshold)
+ final def getThreshold: Double = $(threshold)
}
/**
@@ -167,7 +166,7 @@ private[ml] trait HasInputCol extends Params {
final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
/** @group getParam */
- final def getInputCol: String = getOrDefault(inputCol)
+ final def getInputCol: String = $(inputCol)
}
/**
@@ -182,7 +181,7 @@ private[ml] trait HasInputCols extends Params {
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
/** @group getParam */
- final def getInputCols: Array[String] = getOrDefault(inputCols)
+ final def getInputCols: Array[String] = $(inputCols)
}
/**
@@ -197,7 +196,7 @@ private[ml] trait HasOutputCol extends Params {
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
/** @group getParam */
- final def getOutputCol: String = getOrDefault(outputCol)
+ final def getOutputCol: String = $(outputCol)
}
/**
@@ -212,7 +211,7 @@ private[ml] trait HasCheckpointInterval extends Params {
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
/** @group getParam */
- final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+ final def getCheckpointInterval: Int = $(checkpointInterval)
}
/**
@@ -229,7 +228,7 @@ private[ml] trait HasFitIntercept extends Params {
setDefault(fitIntercept, true)
/** @group getParam */
- final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
+ final def getFitIntercept: Boolean = $(fitIntercept)
}
/**
@@ -246,7 +245,7 @@ private[ml] trait HasSeed extends Params {
setDefault(seed, Utils.random.nextLong())
/** @group getParam */
- final def getSeed: Long = getOrDefault(seed)
+ final def getSeed: Long = $(seed)
}
/**
@@ -261,7 +260,7 @@ private[ml] trait HasElasticNetParam extends Params {
final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
/** @group getParam */
- final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
+ final def getElasticNetParam: Double = $(elasticNetParam)
}
/**
@@ -276,7 +275,7 @@ private[ml] trait HasTol extends Params {
final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
/** @group getParam */
- final def getTol: Double = getOrDefault(tol)
+ final def getTol: Double = $(tol)
}
/**
@@ -291,6 +290,6 @@ private[ml] trait HasStepSize extends Params {
final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.")
/** @group getParam */
- final def getStepSize: Double = getOrDefault(stepSize)
+ final def getStepSize: Double = $(stepSize)
}
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index f9f2b2764d..6cf4b40075 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -59,7 +59,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
/** @group getParam */
- def getRank: Int = getOrDefault(rank)
+ def getRank: Int = $(rank)
/**
* Param for number of user blocks (>= 1).
@@ -70,7 +70,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
ParamValidators.gtEq(1))
/** @group getParam */
- def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
+ def getNumUserBlocks: Int = $(numUserBlocks)
/**
* Param for number of item blocks (>= 1).
@@ -81,7 +81,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
ParamValidators.gtEq(1))
/** @group getParam */
- def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
+ def getNumItemBlocks: Int = $(numItemBlocks)
/**
* Param to decide whether to use implicit preference.
@@ -91,7 +91,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")
/** @group getParam */
- def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs)
+ def getImplicitPrefs: Boolean = $(implicitPrefs)
/**
* Param for the alpha parameter in the implicit preference formulation (>= 0).
@@ -102,7 +102,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
ParamValidators.gtEq(0))
/** @group getParam */
- def getAlpha: Double = getOrDefault(alpha)
+ def getAlpha: Double = $(alpha)
/**
* Param for the column name for user ids.
@@ -112,7 +112,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val userCol = new Param[String](this, "userCol", "column name for user ids")
/** @group getParam */
- def getUserCol: String = getOrDefault(userCol)
+ def getUserCol: String = $(userCol)
/**
* Param for the column name for item ids.
@@ -122,7 +122,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val itemCol = new Param[String](this, "itemCol", "column name for item ids")
/** @group getParam */
- def getItemCol: String = getOrDefault(itemCol)
+ def getItemCol: String = $(itemCol)
/**
* Param for the column name for ratings.
@@ -132,7 +132,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
/** @group getParam */
- def getRatingCol: String = getOrDefault(ratingCol)
+ def getRatingCol: String = $(ratingCol)
/**
* Param for whether to apply nonnegativity constraints.
@@ -143,7 +143,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
this, "nonnegative", "whether to use nonnegative constraint for least squares")
/** @group getParam */
- def getNonnegative: Boolean = getOrDefault(nonnegative)
+ def getNonnegative: Boolean = $(nonnegative)
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
@@ -152,19 +152,17 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
/**
* Validates and transforms the input schema.
* @param schema input schema
- * @param paramMap extra params
* @return output schema
*/
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- assert(schema(map(userCol)).dataType == IntegerType)
- assert(schema(map(itemCol)).dataType== IntegerType)
- val ratingType = schema(map(ratingCol)).dataType
- assert(ratingType == FloatType || ratingType == DoubleType)
- val predictionColName = map(predictionCol)
- assert(!schema.fieldNames.contains(predictionColName),
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ require(schema($(userCol)).dataType == IntegerType)
+ require(schema($(itemCol)).dataType == IntegerType)
+ val ratingType = schema($(ratingCol)).dataType
+ require(ratingType == FloatType || ratingType == DoubleType)
+ val predictionColName = $(predictionCol)
+ require(!schema.fieldNames.contains(predictionColName),
s"Prediction column $predictionColName already exists.")
- val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false)
+ val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
StructType(newFields)
}
}
@@ -174,7 +172,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
*/
class ALSModel private[ml] (
override val parent: ALS,
- override val fittingParamMap: ParamMap,
k: Int,
userFactors: RDD[(Int, Array[Float])],
itemFactors: RDD[(Int, Array[Float])])
@@ -183,9 +180,8 @@ class ALSModel private[ml] (
/** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
import dataset.sqlContext.implicits._
- val map = extractParamMap(paramMap)
val users = userFactors.toDF("id", "features")
val items = itemFactors.toDF("id", "features")
@@ -199,13 +195,13 @@ class ALSModel private[ml] (
}
}
dataset
- .join(users, dataset(map(userCol)) === users("id"), "left")
- .join(items, dataset(map(itemCol)) === items("id"), "left")
- .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
+ .join(users, dataset($(userCol)) === users("id"), "left")
+ .join(items, dataset($(itemCol)) === items("id"), "left")
+ .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -292,25 +288,22 @@ class ALS extends Estimator[ALSModel] with ALSParams {
this
}
- override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
- val map = extractParamMap(paramMap)
+ override def fit(dataset: DataFrame): ALSModel = {
val ratings = dataset
- .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
+ .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType))
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
- val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
- numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
- maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha), nonnegative = map(nonnegative),
- checkpointInterval = map(checkpointInterval))
- val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
- Params.inheritValues(map, this, model)
- model
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
+ numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
+ maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
+ alpha = $(alpha), nonnegative = $(nonnegative),
+ checkpointInterval = $(checkpointInterval))
+ copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
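For the estimator side, a toy sketch (assumed names, not the spark.ml API) of the shape `fit` now takes: train against the embedded param values and return the model through `copyValues`, so the fitted model carries the estimator's params.

~~~scala
// Stand-in types only; rank/predictionCol model the values read via $(...) in the real code.
class ToyModel(val parent: ToyEstimator) { var predictionCol: String = "prediction" }

class ToyEstimator {
  var rank: Int = 10
  var predictionCol: String = "prediction"

  private def copyValues(m: ToyModel): ToyModel = { m.predictionCol = predictionCol; m }

  def fit(ratings: Seq[(Int, Int, Float)]): ToyModel = {
    // ... the factorization would consume the embedded values here ...
    copyValues(new ToyModel(this))  // the fitted model carries the estimator's params
  }
}

object FitDemo extends App {
  val est = new ToyEstimator
  est.predictionCol = "score"
  val model = est.fit(Seq((0, 1, 3.0f)))
  assert(model.predictionCol == "score" && (model.parent eq est))
}
~~~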
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 756725a64b..b07c26fe79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -63,15 +62,13 @@ final class DecisionTreeRegressor
override def setImpurity(value: String): this.type = super.setImpurity(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): DecisionTreeRegressionModel = {
+ override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
val oldModel = OldDecisionTree.train(oldDataset, strategy)
- DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -96,7 +93,6 @@ object DecisionTreeRegressor {
@AlphaComponent
final class DecisionTreeRegressionModel private[ml] (
override val parent: DecisionTreeRegressor,
- override val fittingParamMap: ParamMap,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with Serializable {
@@ -108,10 +104,8 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predict(features)
}
- override protected def copy(): DecisionTreeRegressionModel = {
- val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
+ copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra)
}
override def toString: String = {
@@ -130,12 +124,11 @@ private[ml] object DecisionTreeRegressionModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeRegressor,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
require(oldModel.algo == OldAlgo.Regression,
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
- new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ new DecisionTreeRegressionModel(parent, rootNode)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 76c9837693..bc796958e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -23,20 +23,18 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap, Param}
+import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -111,7 +109,7 @@ final class GBTRegressor
def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
- def getLossType: String = getOrDefault(lossType).toLowerCase
+ def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
@@ -124,16 +122,14 @@ final class GBTRegressor
}
}
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): GBTRegressionModel = {
+ override protected def train(dataset: DataFrame): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
- GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -155,7 +151,6 @@ object GBTRegressor {
@AlphaComponent
final class GBTRegressionModel(
override val parent: GBTRegressor,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTRegressionModel]
@@ -178,10 +173,8 @@ final class GBTRegressionModel(
if (prediction > 0.0) 1.0 else 0.0
}
- override protected def copy(): GBTRegressionModel = {
- val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): GBTRegressionModel = {
+ copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -200,14 +193,13 @@ private[ml] object GBTRegressionModel {
def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
- DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ new GBTRegressionModel(parent, newTrees, oldModel.treeWeights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 0b81c48466..66c475f2d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -19,22 +19,22 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
-import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
-import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import breeze.optimize.{CachedDiffFunction, DiffFunction}
+import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
+ OWLQN => BreezeOWLQN}
+import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter, HasRegParam}
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
-import org.apache.spark.Logging
/**
* Params for linear regression.
@@ -96,9 +96,9 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
- override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
+ override protected def train(dataset: DataFrame): LinearRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist instances.
- val instances = extractLabeledPoints(dataset, paramMap).map {
+ val instances = extractLabeledPoints(dataset).map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -125,7 +125,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
s"and the intercept will be the mean of the label; as a result, training is not needed.")
if (handlePersistence) instances.unpersist()
- return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean)
+ return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean)
}
val featuresMean = summarizer.mean.toArray
@@ -133,17 +133,17 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
// Since we implicitly do the feature scaling when we compute the cost function
// to improve the convergence, the effective regParam will be changed.
- val effectiveRegParam = paramMap(regParam) / yStd
- val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
- val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
+ val effectiveRegParam = $(regParam) / yStd
+ val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
+ val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
featuresStd, featuresMean, effectiveL2RegParam)
- val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
- new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+ val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+ new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
- new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+ new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
}
val initialWeights = Vectors.zeros(numFeatures)
@@ -178,7 +178,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
if (handlePersistence) instances.unpersist()
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
- new LinearRegressionModel(this, paramMap, weights.compressed, intercept)
+ new LinearRegressionModel(this, weights.compressed, intercept)
}
}
@@ -190,7 +190,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
@AlphaComponent
class LinearRegressionModel private[ml] (
override val parent: LinearRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector,
val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
@@ -200,10 +199,8 @@ class LinearRegressionModel private[ml] (
dot(features, weights) + intercept
}
- override protected def copy(): LinearRegressionModel = {
- val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): LinearRegressionModel = {
+ copyValues(new LinearRegressionModel(parent, weights, intercept), extra)
}
}
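A worked restatement of the effective-penalty arithmetic in the hunk above (the numbers are illustrative): the regularization is rescaled by the label's standard deviation, split into L1/L2 parts by `elasticNetParam`, and LBFGS is used only when no L1 term remains.

~~~scala
object EffectiveRegParams extends App {
  val regParam = 0.3
  val elasticNetParam = 0.8   // alpha: 0 => pure L2, 1 => pure L1
  val yStd = 2.0              // label standard deviation, from the summarizer

  val effectiveRegParam = regParam / yStd                                // 0.15
  val effectiveL1RegParam = elasticNetParam * effectiveRegParam          // 0.12
  val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam  // 0.03

  // LBFGS when there is no L1 term, otherwise OWLQN (which handles the L1 part).
  val useLBFGS = elasticNetParam == 0.0 || effectiveRegParam == 0.0
  println((effectiveL1RegParam, effectiveL2RegParam, useLBFGS))
}
~~~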
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2171ef3d32..0468a1be1b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -20,18 +20,17 @@ package org.apache.spark.ml.regression
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -77,17 +76,15 @@ final class RandomForestRegressor
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): RandomForestRegressionModel = {
+ override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
val oldModel = OldRandomForest.trainRegressor(
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
- RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -110,7 +107,6 @@ object RandomForestRegressor {
@AlphaComponent
final class RandomForestRegressionModel private[ml] (
override val parent: RandomForestRegressor,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeRegressionModel])
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -132,10 +128,8 @@ final class RandomForestRegressionModel private[ml] (
_trees.map(_.rootNode.predict(features)).sum / numTrees
}
- override protected def copy(): RandomForestRegressionModel = {
- val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): RandomForestRegressionModel = {
+ copyValues(new RandomForestRegressionModel(parent, _trees), extra)
}
override def toString: String = {
@@ -154,14 +148,13 @@ private[ml] object RandomForestRegressionModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
- DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent, fittingParamMap, newTrees)
+ new RandomForestRegressionModel(parent, newTrees)
}
}
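The prediction in this model is simply the average of the per-tree predictions; a tiny sketch with made-up values:

~~~scala
object ForestAverage extends App {
  val treePredictions = Array(1.0, 3.0, 2.0)  // rootNode.predict(features) per tree
  val prediction = treePredictions.sum / treePredictions.length
  println(prediction)                          // 2.0
}
~~~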
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index d679085eea..c6b3327db6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.regression
-import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index d1ad0893cd..cee2aa6e85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -39,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params {
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
/** @group getParam */
- def getEstimator: Estimator[_] = getOrDefault(estimator)
+ def getEstimator: Estimator[_] = $(estimator)
/**
* param for estimator param maps
@@ -49,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params {
new Param(this, "estimatorParamMaps", "param maps for the estimator")
/** @group getParam */
- def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps)
+ def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
/**
* param for the evaluator for selection
@@ -58,7 +58,7 @@ private[ml] trait CrossValidatorParams extends Params {
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
/** @group getParam */
- def getEvaluator: Evaluator = getOrDefault(evaluator)
+ def getEvaluator: Evaluator = $(evaluator)
/**
* Param for number of folds for cross validation. Must be >= 2.
@@ -69,7 +69,7 @@ private[ml] trait CrossValidatorParams extends Params {
"number of folds for cross validation (>= 2)", ParamValidators.gtEq(2))
/** @group getParam */
- def getNumFolds: Int = getOrDefault(numFolds)
+ def getNumFolds: Int = $(numFolds)
setDefault(numFolds -> 3)
}
@@ -95,23 +95,22 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def validate(paramMap: ParamMap): Unit = {
+ override def validateParams(paramMap: ParamMap): Unit = {
getEstimatorParamMaps.foreach { eMap =>
- getEstimator.validate(eMap ++ paramMap)
+ getEstimator.validateParams(eMap ++ paramMap)
}
}
- override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
- val map = extractParamMap(paramMap)
+ override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
- transformSchema(dataset.schema, paramMap, logging = true)
+ transformSchema(dataset.schema, logging = true)
val sqlCtx = dataset.sqlContext
- val est = map(estimator)
- val eval = map(evaluator)
- val epm = map(estimatorParamMaps)
+ val est = $(estimator)
+ val eval = $(evaluator)
+ val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
+ val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
@@ -121,27 +120,24 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
- val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
+ val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric
i += 1
}
validationDataset.unpersist()
}
- f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
+ f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- val cvModel = new CrossValidatorModel(this, map, bestModel)
- Params.inheritValues(map, this, cvModel)
- cvModel
+ copyValues(new CrossValidatorModel(this, bestModel))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- map(estimator).transformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ $(estimator).transformSchema(schema)
}
}
@@ -152,19 +148,18 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
@AlphaComponent
class CrossValidatorModel private[ml] (
override val parent: CrossValidator,
- override val fittingParamMap: ParamMap,
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def validate(paramMap: ParamMap): Unit = {
- bestModel.validate(paramMap)
+ override def validateParams(paramMap: ParamMap): Unit = {
+ bestModel.validateParams(paramMap)
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- bestModel.transform(dataset, paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
+ bestModel.transform(dataset)
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- bestModel.transformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ bestModel.transformSchema(schema)
}
}
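A small sketch of the metric bookkeeping above, with made-up numbers (the real code averages in place via `f2jBLAS.dscal`): sum the per-candidate metrics over the folds, divide by `numFolds`, and take the argmax.

~~~scala
object PickBestModel extends App {
  val numFolds = 3
  // metrics(i) already holds the sum of fold metrics for candidate param map i
  val metrics = Array(1.8, 2.4, 2.1)
  val averaged = metrics.map(_ / numFolds)
  val (bestMetric, bestIndex) = averaged.zipWithIndex.maxBy(_._1)
  println(s"best index $bestIndex with metric $bestMetric")  // best index 1 with metric 0.8
}
~~~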
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 3f8e59de0f..7e7189a2b1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -84,9 +84,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setThreshold(0.6)
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
- assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
- assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
+ LogisticRegression parent = model.parent();
+ assert(parent.getMaxIter() == 10);
+ assert(parent.getRegParam() == 1.0);
+ assert(parent.getThreshold() == 0.6);
assert(model.getThreshold() == 0.6);
// Modify model params, and check that the params worked.
@@ -109,9 +110,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
- assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
- assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
- assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
+ LogisticRegression parent2 = model2.parent();
+ assert(parent2.getMaxIter() == 5);
+ assert(parent2.getRegParam() == 0.1);
+ assert(parent2.getThreshold() == 0.4);
assert(model2.getThreshold() == 0.4);
assert(model2.getProbabilityCol().equals("theProb"));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 0cc36c8d56..a82b86d560 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -23,14 +23,15 @@ import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
public class JavaLinearRegressionSuite implements Serializable {
@@ -65,8 +66,8 @@ public class JavaLinearRegressionSuite implements Serializable {
DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
- assert(model.getFeaturesCol().equals("features"));
- assert(model.getPredictionCol().equals("prediction"));
+ assertEquals("features", model.getFeaturesCol());
+ assertEquals("prediction", model.getPredictionCol());
}
@Test
@@ -76,14 +77,16 @@ public class JavaLinearRegressionSuite implements Serializable {
.setMaxIter(10)
.setRegParam(1.0);
LinearRegressionModel model = lr.fit(dataset);
- assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
- assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+ LinearRegression parent = model.parent();
+ assertEquals(10, parent.getMaxIter());
+ assertEquals(1.0, parent.getRegParam(), 0.0);
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
- assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
- assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
- assert(model2.getPredictionCol().equals("thePred"));
+ LinearRegression parent2 = model2.parent();
+ assertEquals(5, parent2.getMaxIter());
+ assertEquals(0.1, parent2.getRegParam(), 0.0);
+ assertEquals("thePred", model2.getPredictionCol());
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 0bb6b489f2..08eeca53f0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -68,8 +68,8 @@ public class JavaCrossValidatorSuite implements Serializable {
.setEvaluator(eval)
.setNumFolds(3);
CrossValidatorModel cvModel = cv.fit(dataset);
- ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
- Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
- Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+ LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent();
+ Assert.assertEquals(0.001, parent.getRegParam(), 0.0);
+ Assert.assertEquals(10, parent.getMaxIter());
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 2f175fb117..2b04a30347 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -42,30 +42,32 @@ class PipelineSuite extends FunSuite {
val dataset3 = mock[DataFrame]
val dataset4 = mock[DataFrame]
- when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
- when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
+ when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
+ when(model0.copy(any[ParamMap])).thenReturn(model0)
+ when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
+ when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
+ when(model2.copy(any[ParamMap])).thenReturn(model2)
+ when(transformer3.copy(any[ParamMap])).thenReturn(transformer3)
+
+ when(estimator0.fit(meq(dataset0))).thenReturn(model0)
+ when(model0.transform(meq(dataset0))).thenReturn(dataset1)
when(model0.parent).thenReturn(estimator0)
- when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
- when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
- when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
+ when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
+ when(estimator2.fit(meq(dataset2))).thenReturn(model2)
+ when(model2.transform(meq(dataset2))).thenReturn(dataset3)
when(model2.parent).thenReturn(estimator2)
- when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
+ when(transformer3.transform(meq(dataset3))).thenReturn(dataset4)
val pipeline = new Pipeline()
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
val pipelineModel = pipeline.fit(dataset0)
- assert(pipelineModel.stages.size === 4)
+ assert(pipelineModel.stages.length === 4)
assert(pipelineModel.stages(0).eq(model0))
assert(pipelineModel.stages(1).eq(transformer1))
assert(pipelineModel.stages(2).eq(model2))
assert(pipelineModel.stages(3).eq(transformer3))
- assert(pipelineModel.getModel(estimator0).eq(model0))
- assert(pipelineModel.getModel(estimator2).eq(model2))
- intercept[NoSuchElementException] {
- pipelineModel.getModel(mock[Estimator[MyModel]])
- }
val output = pipelineModel.transform(dataset0)
assert(output.eq(dataset4))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9b31adecdc..03af4ecd7a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -267,8 +267,8 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
- val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
- newTree.fittingParamMap, categoricalFeatures)
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
+ oldTree, newTree.parent, categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e6ccc2c93c..16c758b82c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -129,8 +129,8 @@ private object GBTClassifierSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
- val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent,
- newModel.fittingParamMap, categoricalFeatures)
+ val oldModelAsNew = GBTClassificationModel.fromOld(
+ oldModel, newModel.parent, categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 35d8c2e16c..6dd1fdf055 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -74,9 +74,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
.setThreshold(0.6)
.setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
- assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
- assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+ val parent = model.parent
+ assert(parent.getMaxIter === 10)
+ assert(parent.getRegParam === 1.0)
+ assert(parent.getThreshold === 0.6)
assert(model.getThreshold === 0.6)
// Modify model params, and check that the params worked.
@@ -99,9 +100,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
// Call fit() with new params, and check as many params as we can.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
lr.probabilityCol -> "theProb")
- assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
- assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
- assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+ val parent2 = model2.parent
+ assert(parent2.getMaxIter === 5)
+ assert(parent2.getRegParam === 0.1)
+ assert(parent2.getThreshold === 0.4)
assert(model2.getThreshold === 0.4)
assert(model2.getProbabilityCol == "theProb")
}
@@ -117,7 +119,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
val results = model.transform(dataset)
// Compare rawPrediction with probability
- results.select("rawPrediction", "probability").collect().map {
+ results.select("rawPrediction", "probability").collect().foreach {
case Row(raw: Vector, prob: Vector) =>
assert(raw.size === 2)
assert(prob.size === 2)
@@ -127,7 +129,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
// Compare prediction with probability
- results.select("prediction", "probability").collect().map {
+ results.select("prediction", "probability").collect().foreach {
case Row(pred: Double, prob: Vector) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ed41a9664f..c41def9330 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -159,8 +159,8 @@ private object RandomForestClassifierSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newModel = rf.fit(newData)
// Use parent, fittingParamMap from newTree since these are not checked anyways.
- val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent,
- newModel.fittingParamMap, categoricalFeatures)
+ val oldModelAsNew = RandomForestClassificationModel.fromOld(
+ oldModel, newModel.parent, categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index f8852606ab..6056e7d3f6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -122,19 +122,21 @@ class ParamsSuite extends FunSuite {
assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter))
+ assert(solver.hasParam("inputCol"))
+ assert(!solver.hasParam("abc"))
intercept[NoSuchElementException] {
solver.getParam("abc")
}
intercept[IllegalArgumentException] {
- solver.validate()
+ solver.validateParams()
}
- solver.validate(ParamMap(inputCol -> "input"))
+ solver.validateParams(ParamMap(inputCol -> "input"))
solver.setInputCol("input")
assert(solver.isSet(inputCol))
assert(solver.isDefined(inputCol))
assert(solver.getInputCol === "input")
- solver.validate()
+ solver.validateParams()
intercept[IllegalArgumentException] {
ParamMap(maxIter -> -10)
}
@@ -144,6 +146,11 @@ class ParamsSuite extends FunSuite {
solver.clearMaxIter()
assert(!solver.isSet(maxIter))
+
+ val copied = solver.copy(ParamMap(solver.maxIter -> 50))
+ assert(copied.uid !== solver.uid)
+ assert(copied.getInputCol === solver.getInputCol)
+ assert(copied.getMaxIter === 50)
}
test("ParamValidate") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 6f9c9cb516..dc16073640 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -23,15 +23,19 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
class TestParams extends Params with HasMaxIter with HasInputCol {
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
+
def setInputCol(value: String): this.type = { set(inputCol, value); this }
setDefault(maxIter -> 10)
- override def validate(paramMap: ParamMap): Unit = {
- val m = extractParamMap(paramMap)
- // Note: maxIter is validated when it is set.
- require(m.contains(inputCol))
+ def clearMaxIter(): this.type = clear(maxIter)
+
+ override def validateParams(): Unit = {
+ super.validateParams()
+ require(isDefined(inputCol))
}
- def clearMaxIter(): this.type = clear(maxIter)
+ override def copy(extra: ParamMap): TestParams = {
+ super.copy(extra).asInstanceOf[TestParams]
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index c87a171b4b..5aa81b44dd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -84,8 +84,8 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newTree = dt.fit(newData)
    // Use parent from newTree since it is not checked anyway.
- val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
- newTree.fittingParamMap, categoricalFeatures)
+ val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
+ oldTree, newTree.parent, categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 4aec36948a..25b36ab08b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -130,8 +130,7 @@ private object GBTRegressorSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
    // Use parent from newModel since it is not checked anyway.
- val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent,
- newModel.fittingParamMap, categoricalFeatures)
+ val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c6dc1cc29b..45f09f4fda 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -115,8 +115,8 @@ private object RandomForestRegressorSuite extends FunSuite {
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
    // Use parent from newModel since it is not checked anyway.
- val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent,
- newModel.fittingParamMap, categoricalFeatures)
+ val oldModelAsNew = RandomForestRegressionModel.fromOld(
+ oldModel, newModel.parent, categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 761ea821ef..05313d440f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -49,8 +49,8 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
.setEvaluator(eval)
.setNumFolds(3)
val cvModel = cv.fit(dataset)
- val bestParamMap = cvModel.bestModel.fittingParamMap
- assert(bestParamMap(lr.regParam) === 0.001)
- assert(bestParamMap(lr.maxIter) === 10)
+ val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(parent.getRegParam === 0.001)
+ assert(parent.getMaxIter === 10)
}
}
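Finally, the cross-validation test shows the one place a cast is needed: `bestModel` is typed generically, so its `parent` is cast to the concrete estimator before reading the selected hyperparameters. A sketch of that usage, mirroring the assertions above — illustration only, not part of the patch:
~~~scala
val cvModel = cv.fit(dataset)
val bestLR = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(bestLR.getRegParam === 0.001)   // values selected by the grid search, as asserted above
assert(bestLR.getMaxIter === 10)
~~~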