aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-11-08 20:57:09 -0800
committerReynold Xin <rxin@databricks.com>2015-11-08 20:57:09 -0800
commit97b7080cf2d2846c7257f8926f775f27d457fe7d (patch)
tree28efd3ca15c2e96c0d4f0b5d08cabb9e602ef12e /sql
parentb2d195e137fad88d567974659fa7023ff4da96cd (diff)
downloadspark-97b7080cf2d2846c7257f8926f775f27d457fe7d.tar.gz
spark-97b7080cf2d2846c7257f8926f775f27d457fe7d.tar.bz2
spark-97b7080cf2d2846c7257f8926f775f27d457fe7d.zip
[SPARK-11564][SQL] Dataset Java API audit
A few changes: 1. Removed fold, since it can be confusing for distributed collections. 2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction) 3. Added more documentation and test cases. The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function. Author: Reynold Xin <rxin@databricks.com> Closes #9531 from rxin/SPARK-11564.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala38
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala47
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala100
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java7
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala10
7 files changed, 147 insertions, 96 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index f05e18288d..6569b900fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -100,7 +100,7 @@ object Encoder {
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
Invoke(
- BoundReference(0, ObjectType(cls), true),
+ BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
@@ -114,13 +114,13 @@ object Encoder {
} else {
enc.constructExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
- GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+ GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
}
}
val constructExpression =
- NewInstance(cls, constructExpressions, false, ObjectType(cls))
+ NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
new ExpressionEncoder[Any](
schema,
@@ -130,7 +130,6 @@ object Encoder {
ClassTag.apply(cls))
}
-
def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
@@ -148,9 +147,36 @@ object Encoder {
})
}
- def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+ def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
ExpressionEncoder[(T1, T2)]()
}
+
+ def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ ExpressionEncoder[(T1, T2, T3)]()
+ }
+
+ def forTuple[T1, T2, T3, T4](
+ c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ implicit val typeTag4 = getTypeTag(c4)
+ ExpressionEncoder[(T1, T2, T3, T4)]()
+ }
+
+ def forTuple[T1, T2, T3, T4, T5](
+ c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
+ : Encoder[(T1, T2, T3, T4, T5)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ implicit val typeTag4 = getTypeTag(c4)
+ implicit val typeTag5 = getTypeTag(c5)
+ ExpressionEncoder[(T1, T2, T3, T4, T5)]()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f2d4db5550..8ab958adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1478,18 +1478,54 @@ class DataFrame private[sql](
/**
* Returns the first `n` rows in the [[DataFrame]].
+ *
+ * Running take requires moving data into the application's driver process, and doing so on a
+ * very large dataset can crash the driver process with OutOfMemoryError.
+ *
* @group action
* @since 1.3.0
*/
def take(n: Int): Array[Row] = head(n)
/**
+ * Returns the first `n` rows in the [[DataFrame]] as a list.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 1.6.0
+ */
+ def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+
+ /**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * For Java API, use [[collectAsList]].
+ *
* @group action
* @since 1.3.0
*/
def collect(): Array[Row] = collect(needCallback = true)
+ /**
+ * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
+ *
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 1.3.0
+ */
+ def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+ withNewExecutionId {
+ java.util.Arrays.asList(rdd.collect() : _*)
+ }
+ }
+
private def collect(needCallback: Boolean): Array[Row] = {
def execute(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollectPublic()
@@ -1503,17 +1539,6 @@ class DataFrame private[sql](
}
/**
- * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
- * @group action
- * @since 1.3.0
- */
- def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
- withNewExecutionId {
- java.util.Arrays.asList(rdd.collect() : _*)
- }
- }
-
- /**
* Returns the number of rows in the [[DataFrame]].
* @group action
* @since 1.3.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fecbdac9a6..959e0f5ba0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@@ -75,7 +75,11 @@ class Dataset[T] private[sql](
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
- /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
+ /**
+ * Returns the schema of the encoded form of the objects in this [[Dataset]].
+ *
+ * @since 1.6.0
+ */
def schema: StructType = encoder.schema
/* ************* *
@@ -103,6 +107,7 @@ class Dataset[T] private[sql](
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
+ * @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
@@ -166,8 +171,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
- def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
- filter(t => func.call(t).booleanValue())
+ def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
/**
* (Scala-specific)
@@ -181,7 +185,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
- def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
map(t => func.call(t))(encoder)
/**
@@ -205,10 +209,8 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
- def mapPartitions[U](
- f: FlatMapFunction[java.util.Iterator[T], U],
- encoder: Encoder[U]): Dataset[U] = {
- val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala
mapPartitions(func)(encoder)
}
@@ -248,7 +250,7 @@ class Dataset[T] private[sql](
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
- def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* (Scala-specific)
@@ -262,7 +264,7 @@ class Dataset[T] private[sql](
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
- def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+ def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
foreachPartition(it => func.call(it.asJava))
/* ************* *
@@ -271,7 +273,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -279,33 +281,11 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
- def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
-
- /**
- * (Scala-specific)
- * Aggregates the elements of each partition, and then the results for all the partitions, using a
- * given associative and commutative function and a neutral "zero value".
- *
- * This behaves somewhat differently than the fold operations implemented for non-distributed
- * collections in functional languages like Scala. This fold operation may be applied to
- * partitions individually, and then those results will be folded into the final result.
- * If op is not commutative, then the result may differ from that of a fold applied to a
- * non-distributed collection.
- * @since 1.6.0
- */
- def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
-
- /**
- * (Java-specific)
- * Aggregates the elements of each partition, and then the results for all the partitions, using a
- * given associative and commutative function and a neutral "zero value".
- * @since 1.6.0
- */
- def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+ def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
/**
* (Scala-specific)
@@ -351,7 +331,7 @@ class Dataset[T] private[sql](
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
- def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
/* ****************** *
@@ -367,7 +347,7 @@ class Dataset[T] private[sql](
*/
// Copied from Dataframe to make sure we don't have invalid overloads.
@scala.annotation.varargs
- def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+ protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
@@ -462,8 +442,7 @@ class Dataset[T] private[sql](
* and thus is not affected by a custom `equals` function defined on `T`.
* @since 1.6.0
*/
- def intersect(other: Dataset[T]): Dataset[T] =
- withPlan[T](other)(Intersect)
+ def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
/**
* Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
@@ -473,8 +452,7 @@ class Dataset[T] private[sql](
* duplicate items. As such, it is analagous to `UNION ALL` in SQL.
* @since 1.6.0
*/
- def union(other: Dataset[T]): Dataset[T] =
- withPlan[T](other)(Union)
+ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
@@ -542,27 +520,47 @@ class Dataset[T] private[sql](
def first(): T = rdd.first()
/**
- * Collects the elements to an Array.
+ * Returns an array that contains all the elements in this [[Dataset]].
+ *
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = rdd.collect()
/**
- * (Java-specific)
- * Collects the elements to a Java list.
+ * Returns an array that contains all the elements in this [[Dataset]].
*
- * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
- * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
- * instead and keep the generic type for result.
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
+ * For Java API, use [[collectAsList]].
* @since 1.6.0
*/
- def collectAsList(): java.util.List[T] =
- rdd.collect().toSeq.asJava
+ def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
- /** Returns the first `num` elements of this [[Dataset]] as an Array. */
+ /**
+ * Returns the first `num` elements of this [[Dataset]] as an array.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @since 1.6.0
+ */
def take(num: Int): Array[T] = rdd.take(num)
+ /**
+ * Returns the first `num` elements of this [[Dataset]] as an array.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @since 1.6.0
+ */
+ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
+
/* ******************** *
* Internal Functions *
* ******************** */
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 40bff57a17..d191b50fa8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -65,6 +65,13 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, df.select("key").collect()[0].get(0));
}
+ @Test
+ public void testCollectAndTake() {
+ DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Assert.assertEquals(3, df.select("key").collectAsList().size());
+ Assert.assertEquals(2, df.select("key").takeAsList(2).size());
+ }
+
/**
* See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
*/
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0d3b1a5af5..0f90de774d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable {
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, e.STRING());
- String[] collected = (String[]) ds.collect();
- Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+ List<String> collected = ds.collectAsList();
+ Assert.assertEquals(Arrays.asList("hello", "world"), collected);
+ }
+
+ @Test
+ public void testTake() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ List<String> collected = ds.takeAsList(1);
+ Assert.assertEquals(Arrays.asList("hello"), collected);
}
@Test
@@ -78,16 +86,16 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = context.createDataset(data, e.STRING());
Assert.assertEquals("hello", ds.first());
- Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+ Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@Override
- public Boolean call(String v) throws Exception {
+ public boolean call(String v) throws Exception {
return v.startsWith("h");
}
});
Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
- Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+ Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -95,7 +103,7 @@ public class JavaDatasetSuite implements Serializable {
}, e.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
- Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@Override
public Iterable<String> call(Iterator<String> it) throws Exception {
List<String> ls = new LinkedList<String>();
@@ -128,7 +136,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "b", "c");
Dataset<String> ds = context.createDataset(data, e.STRING());
- ds.foreach(new VoidFunction<String>() {
+ ds.foreach(new ForeachFunction<String>() {
@Override
public void call(String s) throws Exception {
accum.add(1);
@@ -142,28 +150,20 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = context.createDataset(data, e.INT());
- int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+ int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
});
Assert.assertEquals(6, reduced);
-
- int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
- @Override
- public Integer call(Integer v1, Integer v2) throws Exception {
- return v1 * v2;
- }
- });
- Assert.assertEquals(6, folded);
}
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, e.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index fcf03f7180..63b00975e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce(_ + _) == 6)
}
- test("fold") {
- val ds = Seq(1, 2, 3).toDS()
- assert(ds.fold(0)(_ + _) == 6)
- }
-
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped = ds.groupBy(_ % 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6f1174e657..aea5a700d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
}
+ test("as case class - take") {
+ val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
+ assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
+ }
+
test("map") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
@@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
- test("fold") {
- val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
- assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
- }
-
test("joinWith, flat schema") {
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")