aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2016-02-11 15:05:34 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-11 15:05:34 -0800
commit574571c87098795a2206a113ee9ed4bafba8f00f (patch)
tree6552478f8e19aecba3fe13b026ec85ddafaa6966 /mllib
parentefb65e09bcfa4542348f5cd37fe5c14047b862e5 (diff)
downloadspark-574571c87098795a2206a113ee9ed4bafba8f00f.tar.gz
spark-574571c87098795a2206a113ee9ed4bafba8f00f.tar.bz2
spark-574571c87098795a2206a113ee9ed4bafba8f00f.zip
[SPARK-11515][ML] QuantileDiscretizer should take random seed
cc jkbradley Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #9535 from yu-iskw/SPARK-11515.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala2
2 files changed, 11 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 8fd0ce2f2e..2a294d3881 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.{IntParam, _}
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* Params for [[QuantileDiscretizer]].
*/
-private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait QuantileDiscretizerBase extends Params
+ with HasInputCol with HasOutputCol with HasSeed {
/**
* Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
@@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String)
}
override def fit(dataset: DataFrame): Bucketizer = {
- val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
+ val samples = QuantileDiscretizer
+ .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
.map { case Row(feature: Double) => feature }
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
@@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
/**
* Sampling from the given dataset to collect quantile statistics.
*/
- private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
+ private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
val totalSamples = dataset.count()
require(totalSamples > 0,
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
val requiredSamples = math.max(numBins * numBins, 10000)
val fraction = math.min(requiredSamples / dataset.count(), 1.0)
- dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
+ dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 722f1abde4..4fde42972f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
- .setNumBuckets(numBucket)
+ .setNumBuckets(numBucket).setSeed(1)
val result = discretizer.fit(df).transform(df)
val transformedFeatures = result.select("result").collect()