author     Zheng RuiFeng <ruifengz@foxmail.com>    2016-08-04 21:39:45 +0100
committer  Sean Owen <sowen@cloudera.com>          2016-08-04 21:39:45 +0100
commit     be8ea4b2f7ddf1196111acb61fe1a79866376003 (patch)
tree       4d5f36a72aef3dc8e01a327be1aeb1cd3fc01172
parent     ac2a26d09e10c3f462ec773c3ebaa6eedae81ac0 (diff)
[SPARK-16875][SQL] Add args checking for Dataset randomSplit and sample
## What changes were proposed in this pull request?

Add the missing argument checks for randomSplit and sample.

## How was this patch tested?

Unit tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #14478 from zhengruifeng/fix_randomSplit.
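For context, a minimal sketch of how the new checks surface to callers at the Dataset level, runnable in spark-shell against a build containing this patch. The local session setup and the names spark and ds are illustrative, not part of the patch; only the require messages below come from this change, and they throw IllegalArgumentException (via Scala's require) before any logical plan is built.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative local session; any build containing this patch behaves the same.
val spark = SparkSession.builder().master("local[2]").appName("arg-checks").getOrCreate()
val ds = spark.range(100)

// A negative fraction now fails fast at call time, not deep inside sampling:
try ds.sample(withReplacement = false, fraction = -0.1, seed = 42L) catch {
  case e: IllegalArgumentException =>
    // requirement failed: Fraction must be nonnegative, but got -0.1
    println(e.getMessage)
}

// Negative or all-zero weights are rejected the same way:
try ds.randomSplit(Array(-1.0, 2.0), seed = 42L) catch {
  case e: IllegalArgumentException =>
    // requirement failed: Weights must be nonnegative, but got [-1.0,2.0]
    println(e.getMessage)
}

spark.stop()
```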
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala          | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala  | 14
2 files changed, 37 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a4905dd51b..2ee13dc4db 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -474,12 +474,17 @@ abstract class RDD[T: ClassTag](
   def sample(
       withReplacement: Boolean,
       fraction: Double,
-      seed: Long = Utils.random.nextLong): RDD[T] = withScope {
-    require(fraction >= 0.0, "Negative fraction value: " + fraction)
-    if (withReplacement) {
-      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
-    } else {
-      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+      seed: Long = Utils.random.nextLong): RDD[T] = {
+    require(fraction >= 0,
+      s"Fraction must be nonnegative, but got ${fraction}")
+
+    withScope {
+      require(fraction >= 0.0, "Negative fraction value: " + fraction)
+      if (withReplacement) {
+        new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
+      } else {
+        new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+      }
     }
   }
@@ -493,14 +498,22 @@ abstract class RDD[T: ClassTag](
    */
   def randomSplit(
       weights: Array[Double],
-      seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope {
-    val sum = weights.sum
-    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
-    normalizedCumWeights.sliding(2).map { x =>
-      randomSampleWithRange(x(0), x(1), seed)
-    }.toArray
+      seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
+    require(weights.forall(_ >= 0),
+      s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+    require(weights.sum > 0,
+      s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
+    withScope {
+      val sum = weights.sum
+      val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+      normalizedCumWeights.sliding(2).map { x =>
+        randomSampleWithRange(x(0), x(1), seed)
+      }.toArray
+    }
   }
+
   /**
    * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
    * range.
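The RDD-level checks guard the same entry points one layer down. Note that withScope runs its body eagerly, so hoisting the require calls out of it does not change when they fire; it mainly keeps validation separate from the instrumented RDD-creation block. A small sketch of what the new randomSplit checks prevent, assuming a local SparkContext (the setup is illustrative, not part of the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("rdd-arg-checks"))
val rdd = sc.parallelize(1 to 100)

// Before this patch, all-zero weights slipped past validation: weights.map(_ / sum)
// divided by zero, producing NaN cumulative weights and a confusing downstream
// failure. Now the call fails immediately with a clear message:
//   rdd.randomSplit(Array(0.0, 0.0))
//   => java.lang.IllegalArgumentException:
//      requirement failed: Sum of weights must be positive, but got [0.0,0.0]

// Valid weights are normalized into cumulative ranges [0.0, 0.8, 1.0], and each
// split samples its own range with the shared seed, exactly as before.
val Array(train, test) = rdd.randomSplit(Array(0.8, 0.2), seed = 11L)
```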
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 306ca773d4..263ee33742 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1544,8 +1544,13 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
-    Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
+    require(fraction >= 0,
+      s"Fraction must be nonnegative, but got ${fraction}")
+
+    withTypedPlan {
+      Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+    }
   }

   /**
@@ -1573,6 +1578,11 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
+    require(weights.forall(_ >= 0),
+      s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+    require(weights.sum > 0,
+      s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the