author     Xiangrui Meng <meng@databricks.com>  2015-06-19 09:46:51 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-06-19 09:46:51 -0700
commit     43c7ec6384e51105dedf3a53354b6a3732cc27b2 (patch)
tree       a745061321a8989141ddf7e4bad7e7c3e8b3806b /mllib
parent     47af7c1ebfdbd7637f626ab07bf2bda6534f37ea (diff)
[SPARK-8151] [MLLIB] pipeline components should correctly implement copy
Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6622 from mengxr/SPARK-8087 and squashes the following commits:

0e4c8c4 [Xiangrui Meng] fix merge issues
26fc1f0 [Xiangrui Meng] address comments
e607a04 [Xiangrui Meng] merge master
b85b57e [Xiangrui Meng] fix examples/compile
d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy
84ec278 [Xiangrui Meng] remove setter checks due to generics
2cf2ed0 [Xiangrui Meng] snapshot
291814f [Xiangrui Meng] OneVsRest.copy
1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages
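Why this matters: `Transformer.transform(dataset, extra)` is implemented as `this.copy(extra).transform(dataset)`, so a param supplied at transform time only takes effect if `copy` propagates it. Before this patch, `PipelineModel.copy` returned a model holding the same stage instances, so per-stage params in the extra map were silently dropped. A minimal sketch of the now-working path (assuming an existing `dataset` DataFrame with a "words" column):

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.feature.HashingTF
    import org.apache.spark.ml.param.ParamMap

    val hashingTF = new HashingTF()
      .setInputCol("words").setOutputCol("features").setNumFeatures(100)
    val model = new Pipeline().setStages(Array[PipelineStage](hashingTF)).fit(dataset)

    // transform(dataset, extra) desugars to copy(extra).transform(dataset);
    // since PipelineModel.copy now copies its stages, the override below
    // reaches the HashingTF stage instead of being ignored.
    val out = model.transform(dataset, ParamMap(hashingTF.numFeatures -> 10))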
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Model.scala | 5
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 6
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 6
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala | 1
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 16
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 13
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 15
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 11
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 2
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 10
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 12
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 11
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 9
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala | 30
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 10
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala | 28
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala | 12
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 22
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala | 6
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 5
60 files changed, 343 insertions, 55 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index e9a5d7c0e7..57e416591d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
paramMaps.map(fit(dataset, _))
}
- override def copy(extra: ParamMap): Estimator[M] = {
- super.copy(extra).asInstanceOf[Estimator[M]]
- }
+ override def copy(extra: ParamMap): Estimator[M]
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 186bf7ae7a..252acc1565 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null
- override def copy(extra: ParamMap): M = {
- // The default implementation of Params.copy doesn't work for models.
- throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
- }
+ override def copy(extra: ParamMap): M
}
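Making `Model.copy` abstract turns the old runtime `NotImplementedError` into a compile-time obligation. Every model touched in this patch follows the same template: rebuild with the same uid and the same fitted state, then `copyValues(copied, extra)` to carry the params plus the extra overrides onto the new instance. A hypothetical model illustrating the template (names are illustrative, not from this patch):

    import org.apache.spark.ml.Model
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.StructType

    class ScaleModel(override val uid: String, val factor: Double)
      extends Model[ScaleModel] {

      override def copy(extra: ParamMap): ScaleModel = {
        // Same uid and fitted state; copyValues then applies this model's
        // params plus the extra overrides to the fresh instance.
        val copied = new ScaleModel(uid, factor)
        copyValues(copied, extra)
      }

      override def transform(dataset: DataFrame): DataFrame = dataset // no-op sketch
      override def transformSchema(schema: StructType): StructType = schema
    }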
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a9bd28df71..a1f3851d80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -66,9 +66,7 @@ abstract class PipelineStage extends Params with Logging {
outputSchema
}
- override def copy(extra: ParamMap): PipelineStage = {
- super.copy(extra).asInstanceOf[PipelineStage]
- }
+ override def copy(extra: ParamMap): PipelineStage
}
/**
@@ -198,6 +196,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
- new PipelineModel(uid, stages)
+ new PipelineModel(uid, stages.map(_.copy(extra)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e752b81a14..edaa2afb79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -90,9 +90,7 @@ abstract class Predictor[
copyValues(train(dataset).setParent(this))
}
- override def copy(extra: ParamMap): Learner = {
- super.copy(extra).asInstanceOf[Learner]
- }
+ override def copy(extra: ParamMap): Learner
/**
* Train a model using the given dataset and parameters.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index f07f733a5d..3c7bcf7590 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
*/
def transform(dataset: DataFrame): DataFrame
- override def copy(extra: ParamMap): Transformer = {
- super.copy(extra).asInstanceOf[Transformer]
- }
+ override def copy(extra: ParamMap): Transformer
}
/**
@@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
dataset.withColumn($(outputCol),
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}
+
+ override def copy(extra: ParamMap): T = defaultCopy(extra)
}
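`UnaryTransformer` can provide a concrete `copy` because its subclasses are fully described by their params: `defaultCopy` reflectively invokes the single-`String` (uid) constructor and copies the param values over. A hypothetical subclass showing what that buys (illustrative, not part of this patch):

    import org.apache.spark.ml.UnaryTransformer
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.types.{DataType, StringType}

    // Inherits copy(extra) = defaultCopy(extra); the uid constructor below
    // is exactly what defaultCopy needs to instantiate the clone.
    class UpperCaser(override val uid: String)
      extends UnaryTransformer[String, String, UpperCaser] {

      def this() = this(Identifiable.randomUID("upper"))

      override protected def createTransformFunc: String => String = _.toUpperCase
      override protected def outputDataType: DataType = StringType
    }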
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 263d580fe2..14c285dbfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8030e0728a..2dc1824964 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
}
+
+ override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62f4b51f77..554e3b8e05 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f136bcee9c..2e6eedd45a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)
new LogisticRegressionModel(uid, weights.compressed, intercept)
}
+
+ override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 825f9ed1b5..b657882f8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -24,7 +24,7 @@ import scala.language.existentials
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
@@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] (
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}
+
+ override def copy(extra: ParamMap): OneVsRestModel = {
+ val copied = new OneVsRestModel(
+ uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
+ copyValues(copied, extra)
+ }
}
/**
@@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String)
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
copyValues(model)
}
+
+ override def copy(extra: ParamMap): OneVsRest = {
+ val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
+ if (isDefined(classifier)) {
+ copied.setClassifier($(classifier).copy(extra))
+ }
+ copied
+ }
}
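OneVsRest shows why meta-algorithms cannot rely on `defaultCopy` alone: the nested classifier (and, on the model side, the fitted binary models) are `Params` objects in their own right, so the extra map has to be pushed into each of them explicitly. A usage sketch mirroring the new `OneVsRestSuite` test:

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression().setMaxIter(1)
    val ovr = new OneVsRest().setClassifier(lr)

    val copied = ovr.copy(ParamMap(lr.maxIter -> 10))
    // No side effects on the original, and the override reaches the
    // copied classifier.
    assert(lr.getMaxIter == 1)
    assert(copied.getClassifier.getOrDefault(lr.maxIter) == 10)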
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 852a67e066..d3c67494a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f695ddaeef..4a82b77f0e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String)
metrics.unpersist()
metric
}
+
+ override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index 61e937e693..e56c946a06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
*/
def evaluate(dataset: DataFrame): Double
- override def copy(extra: ParamMap): Evaluator = {
- super.copy(extra).asInstanceOf[Evaluator]
- }
+ override def copy(extra: ParamMap): Evaluator
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index abb1b35bed..8670e9679d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.{Param, ParamValidators}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String)
}
metric
}
+
+ override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index b06122d733..46314854d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -83,4 +83,6 @@ final class Binarizer(override val uid: String)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index a3d1f6f65c..67e4785bc3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String)
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
+
+ override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
}
private[feature] object Bucketizer {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 1e758cb775..a359cb8f37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index f936aef80f..319d23e46c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
@@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}
+
+ override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 376b84530c..ecde808105 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/** @group getParam */
def getMinDocFreq: Int = $(minDocFreq)
- /** @group setParam */
- def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
-
/**
* Validate and transform the input schema.
*/
@@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
override def fit(dataset: DataFrame): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
@@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}
/**
@@ -109,4 +111,9 @@ class IDFModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): IDFModel = {
+ val copied = new IDFModel(uid, idfModel)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 8f34878c8d..3825942795 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
}
+
+ override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 442e958202..d85e468562 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT()
+
+ override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index b0fd06d84f..ca3c1cfb56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}
/**
@@ -125,4 +127,9 @@ class StandardScalerModel private[ml] (
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): StandardScalerModel = {
+ val copied = new StandardScalerModel(uid, scaler)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f4e2507575..bf7be363b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
}
/**
@@ -144,4 +146,9 @@ class StringIndexerModel private[ml] (
schema
}
}
+
+ override def copy(extra: ParamMap): StringIndexerModel = {
+ val copied = new StringIndexerModel(uid, labels)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 21c15b6c33..5f9f57a2eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+ override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
/**
@@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+ override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 229ee27ec5..9f83c2ee16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String)
}
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
}
+
+ override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}
private object VectorAssembler {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1d0f23b4fb..f4854a5e4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
@@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
SchemaUtils.appendColumn(schema, $(outputCol), dataType)
}
+
+ override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}
private object VectorIndexer {
@@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] (
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
newAttributeGroup.toStructField()
}
+
+ override def copy(extra: ParamMap): VectorIndexerModel = {
+ val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 36f19509f0..6ea6590956 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
}
/**
@@ -180,4 +182,9 @@ class Word2VecModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): Word2VecModel = {
+ val copied = new Word2VecModel(uid, wordVectors)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ba94d6a3a8..15ebad8838 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable {
/**
* Creates a copy of this instance with the same UID and some extra params.
- * The default implementation tries to create a new instance with the same UID.
+ * Subclasses should implement this method and set the return type properly.
+ *
+ * @see [[defaultCopy()]]
+ */
+ def copy(extra: ParamMap): Params
+
+ /**
+ * Default implementation of copy with extra params.
+ * It tries to create a new instance with the same UID.
* Then it copies the embedded and extra parameters over and returns the new instance.
- * Subclasses should override this method if the default approach is not sufficient.
*/
- def copy(extra: ParamMap): Params = {
+ protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
- copyValues(that, extra)
+ copyValues(that, extra).asInstanceOf[T]
}
/**
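This is the pivot of the whole patch: `copy` loses its reflection-based default and becomes abstract, while the old behavior survives as the `protected final defaultCopy`, which requires a public single-`String` (uid) constructor on the concrete class. Param-only classes can therefore one-line their `copy`; anything carrying non-param state (models, meta-estimators) must build the copy by hand, as the classes above do. A minimal param-only sketch (hypothetical class, mirroring `TestParams` in the test suite):

    import org.apache.spark.ml.param.{ParamMap, Params}
    import org.apache.spark.ml.util.Identifiable

    class MyParams(override val uid: String) extends Params {
      def this() = this(Identifiable.randomUID("myParams"))

      // defaultCopy reflectively calls new MyParams(uid), then copies the
      // embedded param values plus the extra overrides onto the clone.
      override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
    }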
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index df009d855e..2e44cd4cc6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -216,6 +216,11 @@ class ALSModel private[ml] (
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
+
+ override def copy(extra: ParamMap): ALSModel = {
+ val copied = new ALSModel(uid, rank, userFactors, itemFactors)
+ copyValues(copied, extra)
+ }
}
@@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): ALS = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 43b68e7bb2..be1f8063d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0)
}
+
+ override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index b7e374bb6c..036e3acb07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 70cd8e9e87..01306545fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -186,6 +186,8 @@ class LinearRegression(override val uid: String)
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
}
+
+ override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 49a1f7ce8c..21c59061a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cb29392e8b..e2444ab65b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
est.copy(paramMap).validateParams()
}
}
+
+ override def copy(extra: ParamMap): CrossValidator = {
+ val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
+ if (copied.isDefined(estimator)) {
+ copied.setEstimator(copied.getEstimator.copy(extra))
+ }
+ if (copied.isDefined(evaluator)) {
+ copied.setEvaluator(copied.getEvaluator.copy(extra))
+ }
+ copied
+ }
}
/**
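`CrossValidator` gets the same meta-algorithm treatment as `OneVsRest`: the nested estimator and evaluator are themselves copied with the extra map, so overrides survive into the copy. A usage sketch (hypothetical wiring; actually calling `fit` would additionally need `setEstimatorParamMaps`):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.tuning.CrossValidator

    val lr = new LogisticRegression()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator())

    val copied = cv.copy(ParamMap(lr.maxIter -> 50))
    // The copied validator holds its own estimator carrying the override.
    assert(copied.getEstimator.getOrDefault(lr.maxIter) == 50)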
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index efbfeb4059..3fab7ea79b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -159,7 +159,7 @@ private object IDF {
* Represents an IDF model that can transform term frequency vectors.
*/
@Experimental
-class IDFModel private[mllib] (val idf: Vector) extends Serializable {
+class IDFModel private[spark] (val idf: Vector) extends Serializable {
/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 51546d41c3..f087d06d2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -431,7 +431,7 @@ class Word2Vec extends Serializable with Logging {
* Word2Vec model
*/
@Experimental
-class Word2VecModel private[mllib] (
+class Word2VecModel private[spark] (
model: Map[String, Array[Float]]) extends Serializable with Saveable {
// wordList: Ordered list of words obtained from model.
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index ff5929235a..3ae09d39ef 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -102,4 +102,9 @@ public class JavaTestParams extends JavaParams {
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
+
+ @Override
+ public JavaTestParams copy(ParamMap extra) {
+ return defaultCopy(extra);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 29394fefcb..63d2fa31c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -24,6 +24,7 @@ import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
@@ -84,6 +85,15 @@ class PipelineSuite extends SparkFunSuite {
}
}
+ test("PipelineModel.copy") {
+ val hashingTF = new HashingTF()
+ .setNumFeatures(100)
+ val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
+ val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
+ require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+ "copy should handle extra stage params")
+ }
+
test("pipeline model constructors") {
val transform0 = mock[Transformer]
val model1 = mock[MyModel]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index ae40b0b8ff..73b4805c4c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
- DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
}
+ test("params") {
+ ParamsSuite.checkParams(new DecisionTreeClassifier)
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ ParamsSuite.checkParams(model)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1302da3c37..82c345491b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
+ test("params") {
+ ParamsSuite.checkParams(new GBTClassifier)
+ val model = new GBTClassificationModel("gbtc",
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(1.0))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a755cac3ea..5a6265ea99 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -18,8 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new LogisticRegression)
+ val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
test("logistic regression: default params") {
val lr = new LogisticRegression
assert(lr.getLabelCol === "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 1d04ccb509..75cf5bd4ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,15 +19,18 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(rdd)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneVsRest)
+ val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
+ val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
+ ParamsSuite.checkParams(model)
+ }
+
test("one-vs-rest: default params") {
val numClasses = 3
val ova = new OneVsRest()
@@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
+
+ test("OneVsRest.copy and OneVsRestModel.copy") {
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+
+ val ovr = new OneVsRest()
+ withClue("copy with classifier unset should work") {
+ ovr.copy(ParamMap(lr.maxIter -> 10))
+ }
+ ovr.setClassifier(lr)
+ val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
+ require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
+ require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
+ "copy should handle extra classifier params")
+
+ val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+ ovrModel.models.foreach { case m: LogisticRegressionModel =>
+ require(m.getThreshold === 0.1, "copy should handle extra model params")
+ }
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index eee9355a67..1b6b69c7dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestClassifier]].
*/
@@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
}
+ test("params") {
+ ParamsSuite.checkParams(new RandomForestClassifier)
+ val model = new RandomForestClassificationModel("rfc",
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000000..def869fe66
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new BinaryClassificationEvaluator)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 36a1ac6b79..aa722da323 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -18,12 +18,17 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new RegressionEvaluator)
+ }
+
test("Regression Evaluator: default params") {
/**
* Here is the instruction describing how to export the test data into CSV format
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 7953bd0417..2086043983 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
}
+ test("params") {
+ ParamsSuite.checkParams(new Binarizer)
+ }
+
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 507a8a7db2..ec85e0d151 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Bucketizer)
+ }
+
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 7b2d70e644..4157b84b29 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
- val hashingTF = new HashingTF
- ParamsSuite.checkParams(hashingTF, 3)
+ ParamsSuite.checkParams(new HashingTF)
}
test("hashingTF") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index d83772e8be..08f80af034 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new IDF)
+ val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
+ ParamsSuite.checkParams(model)
+ }
+
test("compute IDF with default parameter") {
val numOfFeatures = 4
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 2e5036a844..65846a846b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
@@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneHotEncoder)
+ }
+
test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index feca866cd7..29eebd8960 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
@@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new PolynomialExpansion)
+ }
+
test("Polynomial expansion with default parameter") {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5f557e16e5..99f82bea42 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new StringIndexer)
+ val model = new StringIndexerModel("indexer", Array("a", "b"))
+ ParamsSuite.checkParams(model)
+ }
+
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index ac279cb321..e5fd21c3f6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
+class TokenizerSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new Tokenizer)
+ }
+}
+
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
+ test("params") {
+ ParamsSuite.checkParams(new RegexTokenizer)
+ }
+
test("RegexTokenizer") {
val tokenizer0 = new RegexTokenizer()
.setGaps(false)
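
Note that the new TokenizerSuite mixes in only SparkFunSuite, not MLlibTestSparkContext: checkParams inspects the params and the copy signature by reflection, so it needs no SparkContext, while the existing RegexTokenizer tests keep the heavier fixture.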
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 489abb5af7..bb4d5b983e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new VectorAssembler)
+ }
+
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 06affc7305..8c85c96d5c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
private def getIndexer: VectorIndexer =
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
+ test("params") {
+ ParamsSuite.checkParams(new VectorIndexer)
+ val model = new VectorIndexerModel("indexer", 1, Map.empty)
+ ParamsSuite.checkParams(model)
+ }
+
test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 94ebc3aebf..aa6ce533fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -18,13 +18,21 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Word2Vec)
+ val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 96094d7a09..050d4170ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
object ParamsSuite extends SparkFunSuite {
/**
- * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
- * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
- * the param method name.
+ * Checks common requirements for [[Params.params]]:
+ * - params are ordered by names
+ * - param parent has the same UID as the object's UID
+ * - param name is the same as the param method name
+ * - obj.copy should return the same type as the obj
*/
- def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+ def checkParams(obj: Params): Unit = {
+ val clazz = obj.getClass
+
val params = obj.params
- require(params.length === expectedNumParams,
- s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
val paramNames = params.map(_.name)
- require(paramNames === paramNames.sorted)
+ require(paramNames === paramNames.sorted, "params must be ordered by names")
params.foreach { p =>
assert(p.parent === obj.uid)
assert(obj.getParam(p.name) === p)
+ // TODO: Check that setters return self, which needs special handling for generic types.
}
+
+ val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
+ val copyReturnType = copyMethod.getReturnType
+ require(copyReturnType === obj.getClass,
+ s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}
}
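
The new reflection check compares the declared return type of copy(ParamMap) against the object's concrete class, so an override that widens the return type fails checkParams even though it compiles. A minimal sketch of the difference, assuming Params can be mixed in directly as in the SharedParamsSuite hunk below:

    import org.apache.spark.ml.param.{ParamMap, Params}

    // Compiles, but fails the new require: getMethod("copy",
    // classOf[ParamMap]).getReturnType yields Params, not Loose.
    class Loose(override val uid: String) extends Params {
      override def copy(extra: ParamMap): Params = defaultCopy(extra)
    }

    // Passes: the declared return type matches the concrete class.
    class Tight(override val uid: String) extends Params {
      override def copy(extra: ParamMap): Tight = defaultCopy(extra)
    }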
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index a9e78366ad..2759248344 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
require(isDefined(inputCol))
}
- override def copy(extra: ParamMap): TestParams = {
- super.copy(extra).asInstanceOf[TestParams]
- }
+ override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
}
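
defaultCopy replaces the old super.copy(extra).asInstanceOf[...] boilerplate. Its body lives in params.scala (touched earlier in this patch and not shown here); a plausible sketch, assuming it reflectively invokes the single-String-argument constructor and then copies param values onto the fresh instance:

    // Sketch only; the real definition is in org.apache.spark.ml.param.Params.
    protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
      val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
      copyValues(that, extra).asInstanceOf[T]
    }

If that assumption holds, components must keep a constructor taking only a uid; a class without one compiles against defaultCopy but would presumably fail at runtime if copy were ever invoked, since reflection would not find the constructor.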
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index eb5408d3fe..b3af81a3c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.param.shared
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.param.{ParamMap, Params}
class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
- class Obj(override val uid: String) extends Params with HasOutputCol
+ class Obj(override val uid: String) extends Params with HasOutputCol {
+ override def copy(extra: ParamMap): Obj = defaultCopy(extra)
+ }
val obj = new Obj("obj")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 9b3619f004..36af4b34a9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
-
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
@@ -98,6 +97,8 @@ object CrossValidatorSuite {
override def transformSchema(schema: StructType): StructType = {
throw new UnsupportedOperationException
}
+
+ override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
}
class MyEvaluator extends Evaluator {
@@ -107,5 +108,7 @@ object CrossValidatorSuite {
}
override val uid: String = "eval"
+
+ override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
}
}
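
Taken together, the contract these tests enforce is that copy returns the same concrete type with the extra params layered on top, so callers never need a cast. A usage sketch (values assumed) against one of the components covered by this patch:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    // copy yields a new LogisticRegression, not a widened Params, and the
    // extra maxIter overrides the value set on the original instance.
    val lr = new LogisticRegression().setMaxIter(10)
    val lr2 = lr.copy(ParamMap(lr.maxIter -> 20))
    assert(lr2.getMaxIter == 20)
    assert(lr.getMaxIter == 10)  // the original is left untouched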