diff options
author | Rekha Joshi <rekhajoshm@gmail.com> | 2016-04-11 17:13:30 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-04-11 17:13:30 +0800 |
commit | e82d95bf63f57cefa02dc545ceb451ecdeedce28 (patch) | |
tree | df7113377e2ed5140424f381e94010d98638438f | |
parent | 1a0cca1fc81512d480ed0efc46114cb2b2189183 (diff) | |
download | spark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.tar.gz spark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.tar.bz2 spark-e82d95bf63f57cefa02dc545ceb451ecdeedce28.zip |
[SPARK-14372][SQL] Dataset.randomSplit() needs a Java version
## What changes were proposed in this pull request?
1. Added method randomSplitAsList() to Dataset for Java callers; it returns the splits as a java.util.List
for https://issues.apache.org/jira/browse/SPARK-14372
## How was this patch tested?
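Added a testRandomSplit case to JavaDatasetSuite that checks the number of splits returned by randomSplitAsList().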
Author: Rekha Joshi <rekhajoshm@gmail.com>
Author: Joshi <rekhajoshm@gmail.com>
Closes #12184 from rekhajoshm/SPARK-14372.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 17 | ||||
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 10 |
2 files changed, 26 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2f6d8d109f..e216945fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -1493,6 +1492,8 @@ class Dataset[T] private[sql]( * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. * + * For Java API, use [[randomSplitAsList]]. + * * @group typedrel * @since 2.0.0 */ @@ -1511,6 +1512,20 @@ class Dataset[T] private[sql]( } /** + * Returns a Java list that contains randomly split [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * + * @group typedrel + * @since 2.0.0 + */ + def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = { + val values = randomSplit(weights, seed) + java.util.Arrays.asList(values : _*) + } + + /** * Randomly splits this [[Dataset]] with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f26c57b301..5abd62cbc2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(data, ds.collectAsList()); } + @Test + public void testRandomSplit() { + List<String> data = Arrays.asList("hello", "world", "from", "spark"); + Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + double[] arraySplit = {1, 2, 3}; + + List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1); + Assert.assertEquals("wrong number of splits", randomSplit.size(), 3); + } + /** * For testing error messages when creating an encoder on a private class. This is done * here since we cannot create truly private classes in Scala. |