aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorRekha Joshi <rekhajoshm@gmail.com>2016-04-11 17:13:30 +0800
committerCheng Lian <lian@databricks.com>2016-04-11 17:13:30 +0800
commite82d95bf63f57cefa02dc545ceb451ecdeedce28 (patch)
treedf7113377e2ed5140424f381e94010d98638438f /sql
parent1a0cca1fc81512d480ed0efc46114cb2b2189183 (diff)
downloadspark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.tar.gz
spark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.tar.bz2
spark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.zip
[SPARK-14372][SQL] Dataset.randomSplit() needs a Java version
## What changes were proposed in this pull request? 1.Added method randomSplitAsList() in Dataset for java for https://issues.apache.org/jira/browse/SPARK-14372 ## How was this patch tested? TestSuite Author: Rekha Joshi <rekhajoshm@gmail.com> Author: Joshi <rekhajoshm@gmail.com> Closes #12184 from rekhajoshm/SPARK-14372.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala17
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java10
2 files changed, 26 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2f6d8d109f..e216945fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -1493,6 +1492,8 @@ class Dataset[T] private[sql](
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
+ * For Java API, use [[randomSplitAsList]].
+ *
* @group typedrel
* @since 2.0.0
*/
@@ -1511,6 +1512,20 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a Java list that contains randomly split [[Dataset]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+ val values = randomSplit(weights, seed)
+ java.util.Arrays.asList(values : _*)
+ }
+
+ /**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f26c57b301..5abd62cbc2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
+ @Test
+ public void testRandomSplit() {
+ List<String> data = Arrays.asList("hello", "world", "from", "spark");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ double[] arraySplit = {1, 2, 3};
+
+ List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1);
+ Assert.assertEquals("wrong number of splits", randomSplit.size(), 3);
+ }
+
/**
* For testing error messages when creating an encoder on a private class. This is done
* here since we cannot create truly private classes in Scala.