aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-05-07 01:28:44 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-07 01:28:44 -0700
commit4f87e9562aa0dfe5467d7fbaba9278213106377c (patch)
treecdf3bbd2d617354ed589559d3eafd69a4616582e
parent8b6b46e4ff5f19fb7befecaaa0eda63bf29a0e2c (diff)
downloadspark-4f87e9562aa0dfe5467d7fbaba9278213106377c.tar.gz
spark-4f87e9562aa0dfe5467d7fbaba9278213106377c.tar.bz2
spark-4f87e9562aa0dfe5467d7fbaba9278213106377c.zip
[SPARK-7429] [ML] Params cleanups
Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does. CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #5960 from jkbradley/params-cleanups and squashes the following commits: 118b158 [Joseph K. Bradley] Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does. CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala3
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java1
3 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 51ce19d29c..6d09962fe6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
*
- * Note: Java developers should use the single-parameter [[setDefault()]].
- * Annotating this with varargs causes compilation failures.
- *
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
*/
+ @varargs
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 9208127eb1..ac0d1fed84 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
- transformSchema(dataset.schema, logging = true)
+ transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
@@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
}
override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 8abe575610..532eca4791 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -59,5 +59,6 @@ public class JavaTestParams extends JavaParams {
ParamValidators.inArray(validStrings));
setDefault(myIntParam, 1);
setDefault(myDoubleParam, 0.5);
+ setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
}
}