path: root/mllib
author    Oliver Pierson <ocp@gatech.edu>    2016-02-25 13:24:46 +0000
committer Sean Owen <sowen@cloudera.com>    2016-02-25 13:24:46 +0000
commit    6f8e835c68dff6fcf97326dc617132a41ff9d043 (patch)
tree      d0842e5e46ef3e8c7a3bd0f3873a7bd67af34ba1 /mllib
parent    3fa6491be66dad690ca5329dd32e7c82037ae8c1 (diff)
[SPARK-13444][MLLIB] QuantileDiscretizer chooses bad splits on large DataFrames
## What changes were proposed in this pull request?

Change line 113 of QuantileDiscretizer.scala to `val requiredSamples = math.max(numBins * numBins, 10000.0)` so that `requiredSamples` is a `Double`. This fixes the division on line 114, which currently evaluates to zero whenever `requiredSamples < dataset.count`.

## How was this patch tested?

Manual tests. I was having problems using QuantileDiscretizer with a dataset, and after making this change QuantileDiscretizer behaves as expected.

Author: Oliver Pierson <ocp@gatech.edu>
Author: Oliver Pierson <opierson@umd.edu>

Closes #11319 from oliverpierson/SPARK-13444.
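The root cause is easiest to see in isolation. Below is a minimal, hypothetical sketch (the names and values are illustrative, not taken from the patch) of why the old `requiredSamples / dataset.count()` expression collapses to zero on large DataFrames and why converting to `Double` fixes it:

```scala
// Sketch of the division pitfall addressed by this commit (illustrative values only).
object IntegerDivisionPitfall {
  def main(args: Array[String]): Unit = {
    val numBins = 5
    val datasetCount = 1000000L                                // e.g. a large DataFrame
    val requiredSamples = math.max(numBins * numBins, 10000)   // Int

    // Before the fix: Int / Long is integral division, so whenever
    // requiredSamples < datasetCount the ratio truncates to 0 and
    // effectively nothing is sampled for computing the splits.
    val badFraction = math.min(requiredSamples / datasetCount, 1.0)

    // After the fix: converting to Double before dividing preserves the ratio.
    val goodFraction = math.min(requiredSamples.toDouble / datasetCount, 1.0)

    println(s"before: $badFraction, after: $goodFraction")     // before: 0.0, after: 0.01
  }
}
```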
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala       11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala  20
2 files changed, 29 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 1f4cca1233..769f4406e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -103,6 +103,13 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
+
+ /**
+ * Minimum number of samples required for finding splits, regardless of number of bins. If
+ * the dataset has fewer rows than this value, the entire dataset will be used.
+ */
+ private[spark] val minSamplesRequired: Int = 10000
+
/**
* Sampling from the given dataset to collect quantile statistics.
*/
@@ -110,8 +117,8 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
val totalSamples = dataset.count()
require(totalSamples > 0,
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, 10000)
- val fraction = math.min(requiredSamples / dataset.count(), 1.0)
+ val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
+ val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0)
dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6a2c601bbe..25fabf64d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -71,6 +71,26 @@ class QuantileDiscretizerSuite
}
}
+ test("Test splits on dataset larger than minSamplesRequired") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
+
+ val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
+ val numBuckets = 5
+ val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ .setSeed(1)
+
+ val result = discretizer.fit(df).transform(df)
+ val observedNumBuckets = result.select("result").distinct.count
+
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
+ }
+
test("read/write") {
val t = new QuantileDiscretizer()
.setInputCol("myInputCol")