From 97b7080cf2d2846c7257f8926f775f27d457fe7d Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 8 Nov 2015 20:57:09 -0800
Subject: [SPARK-11564][SQL] Dataset Java API audit

A few changes:

1. Removed fold, since it can be confusing for distributed collections.
2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction)
3. Added more documentation and test cases.

The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function.

Author: Reynold Xin

Closes #9531 from rxin/SPARK-11564.
---
 .../spark/sql/catalyst/encoders/Encoder.scala      |  38 ++++++--
 .../scala/org/apache/spark/sql/DataFrame.scala     |  47 +++++++---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 100 ++++++++++-----------
 .../org/apache/spark/sql/JavaDataFrameSuite.java   |   7 ++
 .../org/apache/spark/sql/JavaDatasetSuite.java     |  36 ++++----
 .../apache/spark/sql/DatasetPrimitiveSuite.scala   |   5 --
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  10 +--
 7 files changed, 147 insertions(+), 96 deletions(-)

(limited to 'sql')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index f05e18288d..6569b900fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.reflect.ClassTag
 
 import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
@@ -100,7 +100,7 @@
       expr.transformUp {
         case BoundReference(0, t: ObjectType, _) =>
           Invoke(
-            BoundReference(0, ObjectType(cls), true),
+            BoundReference(0, ObjectType(cls), nullable = true),
             s"_${index + 1}",
             t)
       }
@@ -114,13 +114,13 @@
       } else {
         enc.constructExpression.transformUp {
           case BoundReference(ordinal, dt, _) =>
-            GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+            GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
         }
       }
     }
 
     val constructExpression =
-      NewInstance(cls, constructExpressions, false, ObjectType(cls))
+      NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
 
     new ExpressionEncoder[Any](
       schema,
@@ -130,7 +130,6 @@
       ClassTag.apply(cls))
   }
 
-
   def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
 
   private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
@@ -148,9 +147,36 @@
     })
   }
 
-  def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
     implicit val typeTag1 = getTypeTag(c1)
     implicit val typeTag2 = getTypeTag(c2)
     ExpressionEncoder[(T1, T2)]()
   }
+
+  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
+    implicit val typeTag1 = getTypeTag(c1)
+    implicit val typeTag2 = getTypeTag(c2)
+    implicit val typeTag3 = getTypeTag(c3)
+    ExpressionEncoder[(T1, T2, T3)]()
+  }
+
+  def forTuple[T1, T2, T3, T4](
+      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
+    implicit val typeTag1 = getTypeTag(c1)
+    implicit val typeTag2 = getTypeTag(c2)
+    implicit val typeTag3 = getTypeTag(c3)
+    implicit val typeTag4 = getTypeTag(c4)
+    ExpressionEncoder[(T1, T2, T3, T4)]()
+  }
+
+  def forTuple[T1, T2, T3, T4, T5](
+      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
+    : Encoder[(T1, T2, T3, T4, T5)] = {
+    implicit val typeTag1 = getTypeTag(c1)
+    implicit val typeTag2 = getTypeTag(c2)
+    implicit val typeTag3 = getTypeTag(c3)
+    implicit val typeTag4 = getTypeTag(c4)
+    implicit val typeTag5 = getTypeTag(c5)
+    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f2d4db5550..8ab958adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1478,18 +1478,54 @@ class DataFrame private[sql](
 
   /**
    * Returns the first `n` rows in the [[DataFrame]].
+   *
+   * Running take requires moving data into the application's driver process, and doing so on a
+   * very large dataset can crash the driver process with OutOfMemoryError.
+   *
    * @group action
    * @since 1.3.0
    */
   def take(n: Int): Array[Row] = head(n)
 
+  /**
+   * Returns the first `n` rows in the [[DataFrame]] as a list.
+   *
+   * Running take requires moving data into the application's driver process, and doing so with
+   * a very large `n` can crash the driver process with OutOfMemoryError.
+   *
+   * @group action
+   * @since 1.6.0
+   */
+  def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+
   /**
    * Returns an array that contains all of [[Row]]s in this [[DataFrame]].
+   *
+   * Running take requires moving data into the application's driver process, and doing so with
+   * a very large `n` can crash the driver process with OutOfMemoryError.
+   *
+   * For Java API, use [[collectAsList]].
+   *
    * @group action
    * @since 1.3.0
    */
   def collect(): Array[Row] = collect(needCallback = true)
 
+  /**
+   * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
+   *
+   * Running collect requires moving all the data into the application's driver process, and
+   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   *
+   * @group action
+   * @since 1.3.0
+   */
+  def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+    withNewExecutionId {
+      java.util.Arrays.asList(rdd.collect() : _*)
+    }
+  }
+
   private def collect(needCallback: Boolean): Array[Row] = {
     def execute(): Array[Row] = withNewExecutionId {
       queryExecution.executedPlan.executeCollectPublic()
@@ -1502,17 +1538,6 @@ class DataFrame private[sql](
     }
   }
 
-  /**
-   * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
-   * @group action
-   * @since 1.3.0
-   */
-  def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
-    withNewExecutionId {
-      java.util.Arrays.asList(rdd.collect() : _*)
-    }
-  }
-
   /**
    * Returns the number of rows in the [[DataFrame]].
    * @group action
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fecbdac9a6..959e0f5ba0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,7 +22,7 @@
 import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
@@ -75,7 +75,11 @@ class Dataset[T] private[sql](
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
     this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
 
-  /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
+  /**
+   * Returns the schema of the encoded form of the objects in this [[Dataset]].
+   *
+   * @since 1.6.0
+   */
   def schema: StructType = encoder.schema
 
   /* ************* *
@@ -103,6 +107,7 @@ class Dataset[T] private[sql](
   /**
    * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
    * the same name after two Datasets have been joined.
+   * @since 1.6.0
    */
   def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
 
@@ -166,8 +171,7 @@ class Dataset[T] private[sql](
    * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
    * @since 1.6.0
    */
-  def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
-    filter(t => func.call(t).booleanValue())
+  def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
 
   /**
    * (Scala-specific)
@@ -181,7 +185,7 @@ class Dataset[T] private[sql](
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
-  def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
     map(t => func.call(t))(encoder)
 
   /**
@@ -205,10 +209,8 @@ class Dataset[T] private[sql](
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
-  def mapPartitions[U](
-      f: FlatMapFunction[java.util.Iterator[T], U],
-      encoder: Encoder[U]): Dataset[U] = {
-    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala
     mapPartitions(func)(encoder)
   }
 
@@ -248,7 +250,7 @@ class Dataset[T] private[sql](
    * Runs `func` on each element of this Dataset.
    * @since 1.6.0
    */
-  def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+  def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
 
   /**
    * (Scala-specific)
@@ -262,7 +264,7 @@ class Dataset[T] private[sql](
    * Runs `func` on each partition of this Dataset.
    * @since 1.6.0
    */
-  def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+  def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
     foreachPartition(it => func.call(it.asJava))
 
   /* ************* *
@@ -271,7 +273,7 @@ class Dataset[T] private[sql](
 
   /**
    * (Scala-specific)
-   * Reduces the elements of this Dataset using the specified binary function.  The given function
+   * Reduces the elements of this Dataset using the specified binary function. The given function
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
@@ -279,33 +281,11 @@ class Dataset[T] private[sql](
 
   /**
    * (Java-specific)
-   * Reduces the elements of this Dataset using the specified binary function.  The given function
+   * Reduces the elements of this Dataset using the specified binary function. The given function
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
   */
-  def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
-
-  /**
-   * (Scala-specific)
-   * Aggregates the elements of each partition, and then the results for all the partitions, using a
-   * given associative and commutative function and a neutral "zero value".
-   *
-   * This behaves somewhat differently than the fold operations implemented for non-distributed
-   * collections in functional languages like Scala. This fold operation may be applied to
-   * partitions individually, and then those results will be folded into the final result.
-   * If op is not commutative, then the result may differ from that of a fold applied to a
-   * non-distributed collection.
-   * @since 1.6.0
-   */
-  def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
-
-  /**
-   * (Java-specific)
-   * Aggregates the elements of each partition, and then the results for all the partitions, using a
-   * given associative and commutative function and a neutral "zero value".
-   * @since 1.6.0
-   */
-  def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+  def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
 
   /**
    * (Scala-specific)
@@ -351,7 +331,7 @@ class Dataset[T] private[sql](
    * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
    * @since 1.6.0
   */
-  def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+  def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
     groupBy(f.call(_))(encoder)
 
   /* ****************** *
@@ -367,7 +347,7 @@ class Dataset[T] private[sql](
    */
   // Copied from Dataframe to make sure we don't have invalid overloads.
   @scala.annotation.varargs
-  def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+  protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
@@ -462,8 +442,7 @@ class Dataset[T] private[sql](
    * and thus is not affected by a custom `equals` function defined on `T`.
    * @since 1.6.0
   */
-  def intersect(other: Dataset[T]): Dataset[T] =
-    withPlan[T](other)(Intersect)
+  def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
 
   /**
    * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
@@ -473,8 +452,7 @@ class Dataset[T] private[sql](
    * duplicate items. As such, it is analagous to `UNION ALL` in SQL.
    * @since 1.6.0
   */
-  def union(other: Dataset[T]): Dataset[T] =
-    withPlan[T](other)(Union)
+  def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
 
   /**
    * Returns a new [[Dataset]] where any elements present in `other` have been removed.
@@ -542,27 +520,47 @@ class Dataset[T] private[sql](
   def first(): T = rdd.first()
 
   /**
-   * Collects the elements to an Array.
+   * Returns an array that contains all the elements in this [[Dataset]].
+   *
+   * Running collect requires moving all the data into the application's driver process, and
+   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   *
+   * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
   def collect(): Array[T] = rdd.collect()
 
   /**
-   * (Java-specific)
-   * Collects the elements to a Java list.
+   * Returns an array that contains all the elements in this [[Dataset]].
    *
-   * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
-   * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
-   * instead and keep the generic type for result.
+   * Running collect requires moving all the data into the application's driver process, and
+   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
    *
+   * For Java API, use [[collectAsList]].
    * @since 1.6.0
   */
-  def collectAsList(): java.util.List[T] =
-    rdd.collect().toSeq.asJava
+  def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
 
-  /** Returns the first `num` elements of this [[Dataset]] as an Array. */
+  /**
+   * Returns the first `num` elements of this [[Dataset]] as an array.
+   *
+   * Running take requires moving data into the application's driver process, and doing so with
+   * a very large `n` can crash the driver process with OutOfMemoryError.
+   *
+   * @since 1.6.0
+   */
   def take(num: Int): Array[T] = rdd.take(num)
 
+  /**
+   * Returns the first `num` elements of this [[Dataset]] as an array.
+   *
+   * Running take requires moving data into the application's driver process, and doing so with
+   * a very large `n` can crash the driver process with OutOfMemoryError.
+   *
+   * @since 1.6.0
+   */
+  def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
+
   /* ******************** *
    *  Internal Functions  *
    * ******************** */
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 40bff57a17..d191b50fa8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -65,6 +65,13 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(1, df.select("key").collect()[0].get(0));
   }
 
+  @Test
+  public void testCollectAndTake() {
+    DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+    Assert.assertEquals(3, df.select("key").collectAsList().size());
+    Assert.assertEquals(2, df.select("key").takeAsList(2).size());
+  }
+
   /**
    * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
    */
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0d3b1a5af5..0f90de774d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable {
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
     Dataset<String> ds = context.createDataset(data, e.STRING());
-    String[] collected = (String[]) ds.collect();
-    Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+    List<String> collected = ds.collectAsList();
+    Assert.assertEquals(Arrays.asList("hello", "world"), collected);
+  }
+
+  @Test
+  public void testTake() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    List<String> collected = ds.takeAsList(1);
+    Assert.assertEquals(Arrays.asList("hello"), collected);
   }
 
   @Test
@@ -78,16 +86,16 @@
     List<String> data = Arrays.asList("hello", "world");
     Dataset<String> ds = context.createDataset(data, e.STRING());
     Assert.assertEquals("hello", ds.first());
 
-    Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+    Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
       @Override
-      public Boolean call(String v) throws Exception {
+      public boolean call(String v) throws Exception {
         return v.startsWith("h");
       }
     });
     Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
 
-    Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+    Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
@@ -95,7 +103,7 @@
     }, e.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
 
-    Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+    Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
       @Override
       public Iterable<String> call(Iterator<String> it) throws Exception {
         List<String> ls = new LinkedList<String>();
@@ -128,7 +136,7 @@
     List<String> data = Arrays.asList("a", "b", "c");
     Dataset<String> ds = context.createDataset(data, e.STRING());
 
-    ds.foreach(new VoidFunction<String>() {
+    ds.foreach(new ForeachFunction<String>() {
       @Override
       public void call(String s) throws Exception {
         accum.add(1);
@@ -142,28 +150,20 @@
     List<Integer> data = Arrays.asList(1, 2, 3);
     Dataset<Integer> ds = context.createDataset(data, e.INT());
 
-    int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+    int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
       public Integer call(Integer v1, Integer v2) throws Exception {
         return v1 + v2;
       }
     });
     Assert.assertEquals(6, reduced);
-
-    int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
-      @Override
-      public Integer call(Integer v1, Integer v2) throws Exception {
-        return v1 * v2;
-      }
-    });
-    Assert.assertEquals(6, folded);
   }
 
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+    GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
@@ -187,7 +187,7 @@
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
-    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index fcf03f7180..63b00975e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     assert(ds.reduce(_ + _) == 6)
   }
 
-  test("fold") {
-    val ds = Seq(1, 2, 3).toDS()
-    assert(ds.fold(0)(_ + _) == 6)
-  }
-
   test("groupBy function, keys") {
     val ds = Seq(1, 2, 3, 4, 5).toDS()
     val grouped = ds.groupBy(_ % 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6f1174e657..aea5a700d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
   }
 
+  test("as case class - take") {
+    val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
+    assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
+  }
+
   test("map") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
@@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
   }
 
-  test("fold") {
-    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
-    assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
-  }
-
   test("joinWith, flat schema") {
    val ds1 = Seq(1, 2, 3).toDS().as("a")
    val ds2 = Seq(1, 2).toDS().as("b")
--
cgit v1.2.3
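A minimal Java usage sketch of the interfaces this patch switches Dataset to (FilterFunction, MapFunction, ReduceFunction) together with the new takeAsList/collectAsList actions. It is modeled on JavaDatasetSuite above and is not part of the commit itself: the test-style method name is made up, and `context` and `e` are assumed to be the SQLContext and encoder accessor fields that suite already defines.

import java.util.Arrays;
import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.ReduceFunction;
import org.apache.spark.sql.Dataset;

// Hypothetical test in the style of JavaDatasetSuite; `context` and `e` come from that suite.
@Test
public void javaApiUsageSketch() {
  List<String> data = Arrays.asList("hello", "world");
  Dataset<String> ds = context.createDataset(data, e.STRING());

  // filter now takes FilterFunction<T>, whose call() returns a primitive boolean.
  Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
    @Override
    public boolean call(String v) throws Exception {
      return v.startsWith("h");
    }
  });

  // map takes MapFunction<T, U> plus an explicit Encoder for the result type.
  Dataset<Integer> lengths = ds.map(new MapFunction<String, Integer>() {
    @Override
    public Integer call(String v) throws Exception {
      return v.length();
    }
  }, e.INT());

  // reduce takes ReduceFunction<T>; fold has been removed by this patch.
  int total = lengths.reduce(new ReduceFunction<Integer>() {
    @Override
    public Integer call(Integer a, Integer b) throws Exception {
      return a + b;
    }
  });

  // The *AsList actions return typed java.util.List results without casting.
  Assert.assertEquals(Arrays.asList("hello"), filtered.takeAsList(1));
  Assert.assertEquals(Arrays.asList(5, 5), lengths.collectAsList());
  Assert.assertEquals(10, total);
}

Compared with the JFunction/JFunction2-based overloads visible in the removed lines, each call site now names a purpose-specific interface, and driver-side results come back as typed lists rather than the Object-typed arrays the old testCollect had to cast.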