aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorOliver Pierson <ocp@gatech.edu>2016-04-11 12:02:48 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-11 12:02:48 -0700
commit89a41c5b7a3f727b44a7f615a1352ca006d12f73 (patch)
tree1c59e13c4fe03bbb0c5717f6c08311a2d2648da2 /mllib
parent2dacc81ec31233e558855a26340ad4662d470387 (diff)
downloadspark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.gz
spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.bz2
spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.zip
[SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer
## What changes were proposed in this pull request? QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR proposes to fix this issue and also refactor QuantileDiscretizer to use approxQuantiles from DataFrame stats functions. ## How was this patch tested? QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR) Author: Oliver Pierson <ocp@gatech.edu> Closes #11553 from oliverpierson/SPARK-13600.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala119
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala115
2 files changed, 65 insertions, 169 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index efe8b93d82..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+ * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature]
- def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
- .collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5..8895d630a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
- test("Test quantile discretizer") {
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 10,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 4,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 3,
- Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
- Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+ test("Test observed number of buckets and their sizes match expected values") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 2,
- Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
- Array("-Infinity, 2.0", "2.0, Infinity"))
+ val datasetSize = 100000
+ val numBuckets = 5
+ val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
- }
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
- test("Test getting splits") {
- val splitTestPoints = Array(
- Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity, Double.PositiveInfinity)
- -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
- Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
- )
- for ((ori, res) <- splitTestPoints) {
- assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+ val relativeError = discretizer.getRelativeError
+ val isGoodBucket = udf {
+ (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
}
+ val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
}
- test("Test splits on dataset larger than minSamplesRequired") {
+ test("Test transform method on unseen data") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
- val numBuckets = 5
- val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+ val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
- .setNumBuckets(numBuckets)
- .setSeed(1)
+ .setNumBuckets(5)
- val result = discretizer.fit(df).transform(df)
- val observedNumBuckets = result.select("result").distinct.count
+ val result = discretizer.fit(trainDF).transform(testDF)
+ val firstBucketSize = result.filter(result("result") === 0.0).count
+ val lastBucketSize = result.filter(result("result") === 4.0).count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
+ assert(firstBucketSize === 30L,
+ s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+ assert(lastBucketSize === 31L,
+ s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
}
test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
.setNumBuckets(6)
testDefaultReadWrite(t)
}
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
- def checkDiscretizedData(
- sc: SparkContext,
- data: Array[Double],
- numBucket: Int,
- expectedResult: Array[Double],
- expectedAttrs: Array[String]): Unit = {
+ test("Verify resulting model has parent") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
- val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
- .setNumBuckets(numBucket).setSeed(1)
+ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(5)
val model = discretizer.fit(df)
assert(model.hasParent)
- val result = model.transform(df)
-
- val transformedFeatures = result.select("result").collect()
- .map { case Row(transformedFeature: Double) => transformedFeature }
- val transformedAttrs = Attribute.fromStructField(result.schema("result"))
- .asInstanceOf[NominalAttribute].values.get
-
- assert(transformedFeatures === expectedResult,
- "Transformed features do not equal expected features.")
- assert(transformedAttrs === expectedAttrs,
- "Transformed attributes do not equal expected attributes.")
}
}