aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2015-05-20 17:26:26 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-20 17:26:26 -0700
commit: c330e52dae6a3ec7e67ca82e2c2f4ea873976458 (patch)
tree: a6e98424c41b264292f6b8f7b777c7dc8e0547f3 /mllib
parent: f2faa7af30662e3bdf15780f8719c71108f8e30b (diff)
download: spark-c330e52dae6a3ec7e67ca82e2c2f4ea873976458.tar.gz
spark-c330e52dae6a3ec7e67ca82e2c2f4ea873976458.tar.bz2
spark-c330e52dae6a3ec7e67ca82e2c2f4ea873976458.zip
[SPARK-7762] [MLLIB] set default value for outputCol
Set a default value for `outputCol` instead of forcing users to name it. This is useful for intermediate transformers in the pipeline.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6289 from mengxr/SPARK-7762 and squashes the following commits:

54edebc [Xiangrui Meng] merge master
bff8667 [Xiangrui Meng] update unit test
171246b [Xiangrui Meng] add unit test for outputCol
a4321bd [Xiangrui Meng] set default value for outputCol
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala | 35
3 files changed, 39 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8b8cb81373..1ffb5eddc3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -49,7 +49,7 @@ private[shared] object SharedParamsCodeGen {
isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
- ParamDesc[String]("outputCol", "output column name"),
+ ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 3a4976d3dd..ed08417bd4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -185,7 +185,7 @@ private[ml] trait HasInputCols extends Params {
}
/**
- * (private[ml]) Trait for shared param outputCol.
+ * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
*/
private[ml] trait HasOutputCol extends Params {
@@ -195,6 +195,8 @@ private[ml] trait HasOutputCol extends Params {
*/
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
+ setDefault(outputCol, uid + "__output")
+
/** @group getParam */
final def getOutputCol: String = $(outputCol)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
new file mode 100644
index 0000000000..ca18fa1ad3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.param.Params
+
+class SharedParamsSuite extends FunSuite {
+
+ test("outputCol") {
+
+ class Obj(override val uid: String) extends Params with HasOutputCol
+
+ val obj = new Obj("obj")
+
+ assert(obj.hasDefault(obj.outputCol))
+ assert(obj.getOrDefault(obj.outputCol) === "obj__output")
+ }
+}