diff options
author | Yu ISHIKAWA <yuu.ishikawa@gmail.com> | 2016-02-11 15:05:34 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-11 15:05:34 -0800 |
commit | 574571c87098795a2206a113ee9ed4bafba8f00f (patch) | |
tree | 6552478f8e19aecba3fe13b026ec85ddafaa6966 | |
parent | efb65e09bcfa4542348f5cd37fe5c14047b862e5 (diff) | |
download | spark-574571c87098795a2206a113ee9ed4bafba8f00f.tar.gz spark-574571c87098795a2206a113ee9ed4bafba8f00f.tar.bz2 spark-574571c87098795a2206a113ee9ed4bafba8f00f.zip |
[SPARK-11515][ML] QuantileDiscretizer should take random seed
cc jkbradley
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Closes #9535 from yu-iskw/SPARK-11515.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 15 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 2 |
2 files changed, 11 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 8fd0ce2f2e..2a294d3881 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param.{IntParam, _} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * Params for [[QuantileDiscretizer]]. */ -private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait QuantileDiscretizerBase extends Params + with HasInputCol with HasOutputCol with HasSeed { /** * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must @@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + override def transformSchema(schema: StructType): StructType = { validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) @@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String) } override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets)) + val samples = QuantileDiscretizer + .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) .map { case Row(feature: Double) => feature } val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) val splits = QuantileDiscretizer.getSplits(candidates) @@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi /** * Sampling from the given dataset to collect quantile statistics. */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, 10000) val fraction = math.min(requiredSamples / dataset.count(), 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() + dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 722f1abde4..4fde42972f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite { val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket) + .setNumBuckets(numBucket).setSeed(1) val result = discretizer.fit(df).transform(df) val transformedFeatures = result.select("result").collect() |