aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-08-04 21:39:45 +0100
committerSean Owen <sowen@cloudera.com>2016-08-04 21:39:45 +0100
commitbe8ea4b2f7ddf1196111acb61fe1a79866376003 (patch)
tree4d5f36a72aef3dc8e01a327be1aeb1cd3fc01172 /sql
parentac2a26d09e10c3f462ec773c3ebaa6eedae81ac0 (diff)
downloadspark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.gz
spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.tar.bz2
spark-be8ea4b2f7ddf1196111acb61fe1a79866376003.zip
[SPARK-16875][SQL] Add args checking for DataSet randomSplit and sample
## What changes were proposed in this pull request? Add the missing args-checking for randomSplit and sample ## How was this patch tested? unit tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #14478 from zhengruifeng/fix_randomSplit.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala14
1 files changed, 12 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 306ca773d4..263ee33742 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1544,8 +1544,13 @@ class Dataset[T] private[sql](
* @group typedrel
* @since 1.6.0
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
- Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
+ require(fraction >= 0,
+ s"Fraction must be nonnegative, but got ${fraction}")
+
+ withTypedPlan {
+ Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+ }
}
/**
@@ -1573,6 +1578,11 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
+ require(weights.forall(_ >= 0),
+ s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+ require(weights.sum > 0,
+ s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the