aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala3
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java1
3 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 51ce19d29c..6d09962fe6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
*
- * Note: Java developers should use the single-parameter [[setDefault()]].
- * Annotating this with varargs causes compilation failures.
- *
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
*/
+ @varargs
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 9208127eb1..ac0d1fed84 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
- transformSchema(dataset.schema, logging = true)
+ transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
@@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
}
override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 8abe575610..532eca4791 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -59,5 +59,6 @@ public class JavaTestParams extends JavaParams {
ParamValidators.inArray(validStrings));
setDefault(myIntParam, 1);
setDefault(myDoubleParam, 0.5);
+ setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
}
}