-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala         15
-rw-r--r--  python/pyspark/ml/feature.py                                         77
3 files changed, 92 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index b28c88aaae..e52d797293 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -48,7 +48,7 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
* otherwise, values outside the splits specified will be treated as errors.
* @group param
*/
- val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+ val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
"Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
"buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
"bucket, which also includes y. The splits should be strictly increasing. " +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 7ebbf106ee..5a7ec29aac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -219,7 +219,7 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
-/** Specialized version of [[Param[Array[T]]]] for Java. */
+/** Specialized version of [[Param[Array[String]]]] for Java. */
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {
@@ -232,6 +232,19 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
+/** Specialized version of [[Param[Array[Double]]]] for Java. */
+class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
+ extends Param[Array[Double]](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
+
+ override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
+
+ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
+ def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
+}
+
/**
 * A param and its value.
*/
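Per the comment on the new overload, it is the entry point for Java and Python
callers. A hedged illustration of what that buys the pyspark wrapper added
below: splits can be given as a plain Python list of floats, which reaches the
JVM as a java.util.List and is converted by this overload into the
Array[Double] that the Scala Bucketizer expects.

    # Illustrative pyspark call (mirrors the doctest added below; assumes a
    # live SparkContext). The list of floats crosses Py4J as a java.util.List
    # and is converted by the w(java.util.List) overload defined above.
    from pyspark.ml.feature import Bucketizer
    bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
                            inputCol="values", outputCol="buckets")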
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f35bc1463d..30e1fd4922 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -84,6 +84,83 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ Maps a column of continuous features to a column of feature buckets.
+
+ >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+ >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
+ ... inputCol="values", outputCol="buckets")
+ >>> bucketed = bucketizer.transform(df).collect()
+ >>> bucketed[0].buckets
+ 0.0
+ >>> bucketed[1].buckets
+ 0.0
+ >>> bucketed[2].buckets
+ 1.0
+ >>> bucketed[3].buckets
+ 2.0
+ >>> bucketizer.setParams(outputCol="b").transform(df).head().b
+ 0.0
+ """
+
+ _java_class = "org.apache.spark.ml.feature.Bucketizer"
+ # a placeholder to make it appear in the generated doc
+ splits = \
+ Param(Params._dummy(), "splits",
+ "Split points for mapping continuous features into buckets. With n+1 splits, " +
+ "there are n buckets. A bucket defined by splits x,y holds values in the " +
+ "range [x,y) except the last bucket, which also includes y. The splits " +
+ "should be strictly increasing. Values at -inf, inf must be explicitly " +
+ "provided to cover all Double values; otherwise, values outside the splits " +
+ "specified will be treated as errors.")
+
+ @keyword_only
+ def __init__(self, splits=None, inputCol=None, outputCol=None):
+ """
+ __init__(self, splits=None, inputCol=None, outputCol=None)
+ """
+ super(Bucketizer, self).__init__()
+ #: param for Split points for mapping continuous features into buckets. With n+1 splits,
+ # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
+ # except the last bucket, which also includes y. The splits should be strictly increasing.
+ # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise,
+ # values outside the splits specified will be treated as errors.
+ self.splits = \
+ Param(self, "splits",
+ "Split points for mapping continuous features into buckets. With n+1 splits, " +
+ "there are n buckets. A bucket defined by splits x,y holds values in the " +
+ "range [x,y) except the last bucket, which also includes y. The splits " +
+ "should be strictly increasing. Values at -inf, inf must be explicitly " +
+ "provided to cover all Double values; otherwise, values outside the splits " +
+ "specified will be treated as errors.")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, splits=None, inputCol=None, outputCol=None):
+ """
+ setParams(self, splits=None, inputCol=None, outputCol=None)
+ Sets params for this Bucketizer.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setSplits(self, value):
+ """
+ Sets the value of :py:attr:`splits`.
+ """
+ self.paramMap[self.splits] = value
+ return self
+
+ def getSplits(self):
+ """
+ Gets the value of splits or its default value.
+ """
+ return self.getOrDefault(self.splits)
+
+
+@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
"""
Maps a sequence of terms to their term frequencies using the