diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-08-04 21:39:45 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-04 21:39:45 +0100 |
commit | be8ea4b2f7ddf1196111acb61fe1a79866376003 (patch) | |
tree | 4d5f36a72aef3dc8e01a327be1aeb1cd3fc01172 /sql | |
parent | ac2a26d09e10c3f462ec773c3ebaa6eedae81ac0 (diff) | |
download | spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.gz spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.bz2 spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.zip |
[SPARK-16875][SQL] Add args checking for DataSet randomSplit and sample
## What changes were proposed in this pull request?
Add the missing args-checking for randomSplit and sample
## How was this patch tested?
unit tests
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #14478 from zhengruifeng/fix_randomSplit.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 306ca773d4..263ee33742 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1544,8 +1544,13 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withTypedPlan { + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + } } /** @@ -1573,6 +1578,11 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the |