diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-08-04 21:39:45 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-04 21:39:45 +0100 |
commit | be8ea4b2f7ddf1196111acb61fe1a79866376003 (patch) | |
tree | 4d5f36a72aef3dc8e01a327be1aeb1cd3fc01172 | |
parent | ac2a26d09e10c3f462ec773c3ebaa6eedae81ac0 (diff) | |
download | spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.gz spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.bz2 spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.zip |
[SPARK-16875][SQL] Add args checking for DataSet randomSplit and sample
## What changes were proposed in this pull request?
Add the missing args-checking for randomSplit and sample
## How was this patch tested?
unit tests
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #14478 from zhengruifeng/fix_randomSplit.
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDD.scala | 37 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 14 |
2 files changed, 37 insertions, 14 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a4905dd51b..2ee13dc4db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -474,12 +474,17 @@ abstract class RDD[T: ClassTag]( def sample( withReplacement: Boolean, fraction: Double, - seed: Long = Utils.random.nextLong): RDD[T] = withScope { - require(fraction >= 0.0, "Negative fraction value: " + fraction) - if (withReplacement) { - new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) - } else { - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + seed: Long = Utils.random.nextLong): RDD[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withScope { + require(fraction >= 0.0, "Negative fraction value: " + fraction) + if (withReplacement) { + new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) + } else { + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + } } } @@ -493,14 +498,22 @@ abstract class RDD[T: ClassTag]( */ def randomSplit( weights: Array[Double], - seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope { - val sum = weights.sum - val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) - normalizedCumWeights.sliding(2).map { x => - randomSampleWithRange(x(0), x(1), seed) - }.toArray + seed: Long = Utils.random.nextLong): Array[RDD[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + + withScope { + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + randomSampleWithRange(x(0), x(1), seed) + }.toArray + } } + /** * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability * range. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 306ca773d4..263ee33742 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1544,8 +1544,13 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withTypedPlan { + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + } } /** @@ -1573,6 +1578,11 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the |