Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala       | 119
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala  | 115
2 files changed, 65 insertions(+), 169 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index efe8b93d82..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+ * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for a description).
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
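For context (illustrative, not part of the patch): the relativeError value is passed straight through to DataFrameStatFunctions.approxQuantile, where smaller values give more accurate quantiles at higher cost (0.0 requests exact quantiles). A minimal sketch of that call, assuming an existing DataFrame df with a numeric column named "input":

  // Approximate 25th/50th/75th percentiles of "input" with a relative error of 0.001.
  val quartiles: Array[Double] =
    df.stat.approxQuantile("input", Array(0.25, 0.5, 0.75), 0.001)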
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
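To make the new split construction concrete (values below are illustrative, not taken from the patch): with numBuckets = 4 the probabilities handed to approxQuantile are 0.0, 0.25, 0.5, 0.75 and 1.0, and the first and last returned quantiles are then overwritten with infinities before building the Bucketizer:

  val numBuckets = 4
  val probabilities = (0.0 to 1.0 by 1.0 / numBuckets).toArray
  // probabilities == Array(0.0, 0.25, 0.5, 0.75, 1.0)
  // If approxQuantile returned Array(1.0, 25.0, 50.0, 75.0, 100.0), the final
  // Bucketizer splits become Array(-Infinity, 25.0, 50.0, 75.0, Infinity), i.e. 4 buckets.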
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature]
- def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
- .collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5..8895d630a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
- test("Test quantile discretizer") {
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 10,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 4,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 3,
- Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
- Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+ test("Test observed number of buckets and their sizes match expected values") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 2,
- Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
- Array("-Infinity, 2.0", "2.0, Infinity"))
+ val datasetSize = 100000
+ val numBuckets = 5
+ val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
- }
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
- test("Test getting splits") {
- val splitTestPoints = Array(
- Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity, Double.PositiveInfinity)
- -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
- Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
- )
- for ((ori, res) <- splitTestPoints) {
- assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+ val relativeError = discretizer.getRelativeError
+ val isGoodBucket = udf {
+ (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
}
+ val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
}
- test("Test splits on dataset larger than minSamplesRequired") {
+ test("Test transform method on unseen data") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
- val numBuckets = 5
- val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+ val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
- .setNumBuckets(numBuckets)
- .setSeed(1)
+ .setNumBuckets(5)
- val result = discretizer.fit(df).transform(df)
- val observedNumBuckets = result.select("result").distinct.count
+ val result = discretizer.fit(trainDF).transform(testDF)
+ val firstBucketSize = result.filter(result("result") === 0.0).count
+ val lastBucketSize = result.filter(result("result") === 4.0).count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
+ assert(firstBucketSize === 30L,
+ s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+ assert(lastBucketSize === 31L,
+ s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
}
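The expected counts above are consistent with the splits learned from the training data: fitting on 1..100 with five buckets puts the interior splits at approximately 20.0, 40.0, 60.0 and 80.0, so on the test range -10..110 the first bucket [-Infinity, 20.0) holds the 30 values -10 through 19 and the last bucket [80.0, +Infinity) holds the 31 values 80 through 110.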
test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
.setNumBuckets(6)
testDefaultReadWrite(t)
}
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
- def checkDiscretizedData(
- sc: SparkContext,
- data: Array[Double],
- numBucket: Int,
- expectedResult: Array[Double],
- expectedAttrs: Array[String]): Unit = {
+ test("Verify resulting model has parent") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
- val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
- .setNumBuckets(numBucket).setSeed(1)
+ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(5)
val model = discretizer.fit(df)
assert(model.hasParent)
- val result = model.transform(df)
-
- val transformedFeatures = result.select("result").collect()
- .map { case Row(transformedFeature: Double) => transformedFeature }
- val transformedAttrs = Attribute.fromStructField(result.schema("result"))
- .asInstanceOf[NominalAttribute].values.get
-
- assert(transformedFeatures === expectedResult,
- "Transformed features do not equal expected features.")
- assert(transformedAttrs === expectedAttrs,
- "Transformed attributes do not equal expected attributes.")
}
}
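A hedged end-to-end usage sketch of the updated API (assumes a live SparkContext `sc`, mirroring the test setup above; column names and bucket counts are illustrative):

  import org.apache.spark.ml.feature.QuantileDiscretizer
  import org.apache.spark.sql.SQLContext

  val sqlCtx = SQLContext.getOrCreate(sc)
  import sqlCtx.implicits._

  // 1,000 evenly spaced points to bucketize.
  val df = sc.parallelize(1.0 to 1000.0 by 1.0).map(Tuple1.apply).toDF("input")

  val discretizer = new QuantileDiscretizer()
    .setInputCol("input")
    .setOutputCol("result")
    .setNumBuckets(10)
    .setRelativeError(0.01)   // the new parameter introduced by this patch

  // fit() now derives splits via approxQuantile and returns a Bucketizer.
  val bucketizer = discretizer.fit(df)
  bucketizer.transform(df).show(5)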