aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-05-07 10:25:41 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-07 10:25:41 -0700
commit9e2ffb13287e6efe256b8d23a4654e4cc305e20b (patch)
tree79a13615578199c2907d371c965ef031307c47b9 /mllib
parented9be06a4797bbb678355b361054c8872ac20b75 (diff)
downloadspark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.gz
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.bz2
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.zip
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast`, which is an internal function to obtain `Array[T]` from `Seq[T]`. Author: Burak Yavuz <brkyvz@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #5930 from brkyvz/ml-feat and squashes the following commits: 73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388 c221db9 [Xiangrui Meng] overload StringArrayParam.w c81072d [Burak Yavuz] addressed comments 99c2ebf [Burak Yavuz] add to python_shared_params 39ecb07 [Burak Yavuz] fix scalastyle 7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala27
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala2
4 files changed, 27 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 8f2e62a8e2..b5a69cee6d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
/**
* :: AlphaComponent ::
- * A feature transformer than merge multiple columns into a vector column.
+ * A feature transformer that merges multiple columns into a vector column.
*/
@AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 6d09962fe6..0e1b60d172 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,6 +22,7 @@ import java.util.NoSuchElementException
import scala.annotation.varargs
import scala.collection.mutable
+import scala.collection.JavaConverters._
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
@@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
+/** Specialized version of [[Param[Array[String]]]] for Java. */
+class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
+ extends Param[Array[String]](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
+
+ override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
+
+ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
+ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
+}
+
/**
* A param and its value.
*/
@@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable {
* Sets a parameter in the embedded param map.
*/
protected final def set[T](param: Param[T], value: T): this.type = {
- shouldOwn(param)
- paramMap.put(param.asInstanceOf[Param[Any]], value)
- this
+ set(param -> value)
}
/**
@@ -323,6 +335,15 @@ trait Params extends Identifiable with Serializable {
}
/**
+ * Sets a parameter in the embedded param map.
+ */
+ protected final def set(paramPair: ParamPair[_]): this.type = {
+ shouldOwn(paramPair.param)
+ paramMap.put(paramPair)
+ this
+ }
+
+ /**
* Optionally returns the user-supplied value of a param.
*/
final def get[T](param: Param[T]): Option[T] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 0e1ff97a8b..5085b798da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -85,6 +85,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
+ case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 87f86807c3..7525d37007 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
- final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
+ final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)