From 7e9a9e603abce8689938bdd62d04b29299644aa4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 6 Nov 2015 15:37:07 -0800
Subject: [SPARK-11269][SQL] Java API support & test cases for Dataset

This simply brings https://github.com/apache/spark/pull/9358 up-to-date.

Author: Wenchen Fan
Author: Reynold Xin

Closes #9528 from rxin/dataset-java.
---
 .../spark/sql/catalyst/encoders/Encoder.scala      | 123 ++++++-
 .../spark/sql/catalyst/expressions/objects.scala   |  21 ++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 126 +++++++-
 .../scala/org/apache/spark/sql/DatasetHolder.scala |   6 +-
 .../org/apache/spark/sql/GroupedDataset.scala      |  17 +
 .../scala/org/apache/spark/sql/SQLContext.scala    |   4 +
 .../org/apache/spark/sql/JavaDatasetSuite.java     | 357 +++++++++++++++++++++
 .../apache/spark/sql/DatasetPrimitiveSuite.scala   |   2 +-
 8 files changed, 644 insertions(+), 12 deletions(-)
 create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 329a132d3d..f05e18288d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,11 +17,11 @@
 package org.apache.spark.sql.catalyst.encoders
-
-
 import scala.reflect.ClassTag
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
   /** A ClassTag that can be used to construct and Array to contain a collection of `T`.
*/ def clsTag: ClassTag[T] } + +object Encoder { + import scala.reflect.runtime.universe._ + + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + + def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { + tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2)]] + } + + def tuple[T1, T2, T3]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + } + + def tuple[T1, T2, T3, T4]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + } + + def tuple[T1, T2, T3, T4, T5]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4], + enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + } + + private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + assert(encoders.length > 1) + // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. 
+ assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t: ObjectType, _) => + Invoke( + BoundReference(0, ObjectType(cls), true), + s"_${index + 1}", + t) + } + } + + val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.constructExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + enc.constructExpression.transformUp { + case BoundReference(ordinal, dt, _) => + GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + } + } + } + + val constructExpression = + NewInstance(cls, constructExpressions, false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + + def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] + + private def getTypeTag[T](c: Class[T]): TypeTag[T] = { + import scala.reflect.api + + // val mirror = runtimeMirror(c.getClassLoader) + val mirror = rootMirror + val sym = mirror.staticClass(c.getName) + val tpe = sym.selfType + TypeTag(mirror, new api.TypeCreator { + def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) = + if (m eq mirror) tpe.asInstanceOf[U # Type] + else throw new IllegalArgumentException( + s"Type tag defined in $mirror cannot be migrated to other mirrors.") + }) + } + + def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + ExpressionEncoder[(T1, T2)]() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 8185528976..4f58464221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" } } + +case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) + extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val row = child.gen(ctx) + s""" + ${row.code} + final boolean ${ev.isNull} = ${row.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; + } + """ + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4bca9c3b3f..fecbdac9a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} + import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner @@ -151,18 +155,37 @@ class Dataset[T] private[sql]( def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) /** + * (Scala-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = + filter(t => func.call(t).booleanValue()) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + map(t => func.call(t))(encoder) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ @@ -177,30 +200,77 @@ class Dataset[T] private[sql]( logicalPlan)) } + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def mapPartitions[U]( + f: FlatMapFunction[java.util.Iterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + mapPartitions(func)(encoder) + } + + /** + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) + /** + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterable[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } + /* ************** * * Side effects * * ************** */ /** + * (Scala-specific) * Runs `func` on each element of this Dataset. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** + * (Java-specific) + * Runs `func` on each element of this Dataset. + * @since 1.6.0 + */ + def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + + /** + * (Scala-specific) * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + /** + * (Java-specific) + * Runs `func` on each partition of this Dataset. 
+ * @since 1.6.0 + */ + def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + foreachPartition(it => func.call(it.asJava)) + /* ************* * * Aggregation * * ************* */ /** + * (Scala-specific) * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 @@ -208,6 +278,15 @@ class Dataset[T] private[sql]( def reduce(func: (T, T) => T): T = rdd.reduce(func) /** + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given function + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) + + /** + * (Scala-specific) * Aggregates the elements of each partition, and then the results for all the partitions, using a * given associative and commutative function and a neutral "zero value". * @@ -221,6 +300,15 @@ class Dataset[T] private[sql]( def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) /** + * (Java-specific) + * Aggregates the elements of each partition, and then the results for all the partitions, using a + * given associative and commutative function and a neutral "zero value". + * @since 1.6.0 + */ + def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + + /** + * (Scala-specific) * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ @@ -258,6 +346,14 @@ class Dataset[T] private[sql]( keyAttributes) } + /** + * (Java-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * @since 1.6.0 + */ + def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(f.call(_))(encoder) + /* ****************** * * Typed Relational * * ****************** */ @@ -267,8 +363,7 @@ class Dataset[T] private[sql]( * {{{ * df.select($"colA", $"colB" + 1) * }}} - * @group dfops - * @since 1.3.0 + * @since 1.6.0 */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs @@ -279,7 +374,7 @@ class Dataset[T] private[sql]( * * {{{ * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(e[Int]("value + 1")) + * val newDS = ds.select(expr("value + 1").as[Int]) * }}} * @since 1.6.0 */ @@ -405,6 +500,8 @@ class Dataset[T] private[sql]( * This type of join can be useful both for preserving type-safety with the original object * types as well as working with relational data where either side of the join has column * names in common. + * + * @since 1.6.0 */ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { val left = this.logicalPlan @@ -438,12 +535,31 @@ class Dataset[T] private[sql]( * Gather to Driver Actions * * ************************** */ - /** Returns the first element in this [[Dataset]]. */ + /** + * Returns the first element in this [[Dataset]]. + * @since 1.6.0 + */ def first(): T = rdd.first() - /** Collects the elements to an Array. */ + /** + * Collects the elements to an Array. + * @since 1.6.0 + */ def collect(): Array[T] = rdd.collect() + /** + * (Java-specific) + * Collects the elements to a Java list. + * + * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at + * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method + * instead and keep the generic type for result. 
+ * + * @since 1.6.0 + */ + def collectAsList(): java.util.List[T] = + rdd.collect().toSeq.asJava + /** Returns the first `num` elements of this [[Dataset]] as an Array. */ def take(num: Int): Array[T] = rdd.take(num) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 45f0098b92..08097e9f02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -27,9 +27,9 @@ package org.apache.spark.sql * * @since 1.6.0 */ -case class DatasetHolder[T] private[sql](private val df: Dataset[T]) { +case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDS(): Dataset[T] = df + // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. + def toDS(): Dataset[T] = ds } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b8fc373dff..b2803d5a9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql +import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} @@ -104,6 +108,12 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } + def mapGroups[U]( + f: JFunction2[K, JIterator[T], JIterator[U]], + encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + // To ensure valid overloading. protected def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) @@ -196,4 +206,11 @@ class GroupedDataset[K, T] private[sql]( this.logicalPlan, other.logicalPlan)) } + + def cogroup[U, R]( + other: GroupedDataset[K, U], + f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5ad3871093..5598731af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -508,6 +508,10 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. 
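Taken together, the Scala hunks above expose the new Dataset API to Java: SQLContext gains a java.util.List-based createDataset, Dataset gains Function/Function2/FlatMapFunction/VoidFunction overloads that take an explicit Encoder, and GroupedDataset gains Java variants of mapGroups and cogroup. A minimal stand-alone sketch of how a Java caller might exercise these overloads follows; it mirrors the test suite below, and the local-mode SparkContext/SQLContext wiring plus the class name JavaDatasetExample are illustrative rather than part of the patch.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.encoders.Encoder$;

public class JavaDatasetExample {
  public static void main(String[] args) {
    SparkContext sc = new SparkContext("local[*]", "java-dataset-example");
    SQLContext sqlContext = new SQLContext(sc);
    // The Encoder companion object provides the primitive encoders (STRING, INT, ...).
    Encoder$ e = Encoder$.MODULE$;

    // New in this patch: createDataset from a java.util.List plus an explicit Encoder.
    List<String> data = Arrays.asList("hello", "world");
    Dataset<String> ds = sqlContext.createDataset(data, e.STRING());

    // The Java overloads take the result Encoder explicitly, since Scala implicits
    // are not available from Java.
    Dataset<Integer> lengths = ds.map(new Function<String, Integer>() {
      @Override
      public Integer call(String s) throws Exception {
        return s.length();
      }
    }, e.INT());

    // collectAsList() keeps the element type, unlike collect(), which erases to Object.
    System.out.println(lengths.collectAsList());  // [5, 5]

    sc.stop();
  }
}

Passing the Encoder as an ordinary argument is the Java-side replacement for the `U : Encoder` context bound used by the Scala overloads.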
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java new file mode 100644 index 0000000000..a9493d576d --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; +import scala.Tuple5; +import org.junit.*; + +import org.apache.spark.Accumulator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.catalyst.encoders.Encoder; +import org.apache.spark.sql.catalyst.encoders.Encoder$; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.test.TestSQLContext; + +import static org.apache.spark.sql.functions.*; + +public class JavaDatasetSuite implements Serializable { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + private transient Encoder$ e = Encoder$.MODULE$; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + private Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2(t1, t2); + } + + @Test + public void testCollect() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + String[] collected = (String[]) ds.collect(); + Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + } + + @Test + public void testCommonOperation() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + Assert.assertEquals("hello", ds.first()); + + Dataset filtered = ds.filter(new Function() { + @Override + public Boolean call(String v) throws Exception { + return v.startsWith("h"); + } + }); + Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); + + + Dataset mapped = ds.map(new Function() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, e.INT()); + Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); + + Dataset parMapped = ds.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator it) throws Exception { + List ls = new 
LinkedList(); + while (it.hasNext()) { + ls.add(it.next().toUpperCase()); + } + return ls; + } + }, e.STRING()); + Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); + + Dataset flatMapped = ds.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String s) throws Exception { + List ls = new LinkedList(); + for (char c : s.toCharArray()) { + ls.add(String.valueOf(c)); + } + return ls; + } + }, e.STRING()); + Assert.assertEquals( + Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), + flatMapped.collectAsList()); + } + + @Test + public void testForeach() { + final Accumulator accum = jsc.accumulator(0); + List data = Arrays.asList("a", "b", "c"); + Dataset ds = context.createDataset(data, e.STRING()); + + ds.foreach(new VoidFunction() { + @Override + public void call(String s) throws Exception { + accum.add(1); + } + }); + Assert.assertEquals(3, accum.value().intValue()); + } + + @Test + public void testReduce() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, e.INT()); + + int reduced = ds.reduce(new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(6, reduced); + + int folded = ds.fold(1, new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 * v2; + } + }); + Assert.assertEquals(6, folded); + } + + @Test + public void testGroupBy() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, e.STRING()); + GroupedDataset grouped = ds.groupBy(new Function() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, e.INT()); + + Dataset mapped = grouped.mapGroups( + new Function2, Iterator>() { + @Override + public Iterator call(Integer key, Iterator data) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (data.hasNext()) { + sb.append(data.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + List data2 = Arrays.asList(2, 6, 10); + Dataset ds2 = context.createDataset(data2, e.INT()); + GroupedDataset grouped2 = ds2.groupBy(new Function() { + @Override + public Integer call(Integer v) throws Exception { + return v / 2; + } + }, e.INT()); + + Dataset cogrouped = grouped.cogroup( + grouped2, + new Function3, Iterator, Iterator>() { + @Override + public Iterator call( + Integer key, + Iterator left, + Iterator right) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (left.hasNext()) { + sb.append(left.next()); + } + sb.append("#"); + while (right.hasNext()) { + sb.append(right.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + } + + @Test + public void testGroupByColumn() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, e.STRING()); + GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); + + Dataset mapped = grouped.mapGroups( + new Function2, Iterator>() { + @Override + public Iterator call(Integer key, Iterator data) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (data.hasNext()) { + sb.append(data.next()); + } + return 
Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + } + + @Test + public void testSelect() { + List data = Arrays.asList(2, 6); + Dataset ds = context.createDataset(data, e.INT()); + + Dataset> selected = ds.select( + expr("value + 1").as(e.INT()), + col("value").cast("string").as(e.STRING())); + + Assert.assertEquals( + Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), + selected.collectAsList()); + } + + @Test + public void testSetOperation() { + List data = Arrays.asList("abc", "abc", "xyz"); + Dataset ds = context.createDataset(data, e.STRING()); + + Assert.assertEquals( + Arrays.asList("abc", "xyz"), + sort(ds.distinct().collectAsList().toArray(new String[0]))); + + List data2 = Arrays.asList("xyz", "foo", "foo"); + Dataset ds2 = context.createDataset(data2, e.STRING()); + + Dataset intersected = ds.intersect(ds2); + Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); + + Dataset unioned = ds.union(ds2); + Assert.assertEquals( + Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), + sort(unioned.collectAsList().toArray(new String[0]))); + + Dataset subtracted = ds.subtract(ds2); + Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); + } + + private > List sort(T[] data) { + Arrays.sort(data); + return Arrays.asList(data); + } + + @Test + public void testJoin() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, e.INT()).as("a"); + List data2 = Arrays.asList(2, 3, 4); + Dataset ds2 = context.createDataset(data2, e.INT()).as("b"); + + Dataset> joined = + ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); + Assert.assertEquals( + Arrays.asList(tuple2(2, 2), tuple2(3, 3)), + joined.collectAsList()); + } + + @Test + public void testTupleEncoder() { + Encoder> encoder2 = e.tuple(e.INT(), e.STRING()); + List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); + Dataset> ds2 = context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + Encoder> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING()); + List> data3 = + Arrays.asList(new Tuple3(1, 2L, "a")); + Dataset> ds3 = context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + + Encoder> encoder4 = + e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING()); + List> data4 = + Arrays.asList(new Tuple4(1, "b", 2L, "a")); + Dataset> ds4 = context.createDataset(data4, encoder4); + Assert.assertEquals(data4, ds4.collectAsList()); + + Encoder> encoder5 = + e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN()); + List> data5 = + Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); + Dataset> ds5 = + context.createDataset(data5, encoder5); + Assert.assertEquals(data5, ds5.collectAsList()); + } + + @Test + public void testNestedTupleEncoder() { + // test ((int, string), string) + Encoder, String>> encoder = + e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING()); + List, String>> data = + Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); + Dataset, String>> ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + + // test (int, (string, string, long)) + Encoder>> encoder2 = + e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG())); + List>> data2 = + Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); + Dataset>> ds2 = + context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + // test (int, 
((string, long), string)) + Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 = + e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING())); + List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 = + Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); + Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 = + context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 32443557fb..e3b0346f85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -59,7 +59,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.accumulator(0) - ds.foreach(acc +=) + ds.foreach(acc += _) assert(acc.value == 6) }
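The tuple encoders defined in Encoder.scala above compose the primitive encoders into struct-shaped encoders for scala.TupleN, which is what testTupleEncoder and testNestedTupleEncoder exercise. A minimal sketch of the same composition from Java, again with illustrative local-mode wiring and a hypothetical class name:

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.encoders.Encoder;
import org.apache.spark.sql.catalyst.encoders.Encoder$;

public class TupleEncoderExample {
  public static void main(String[] args) {
    SparkContext sc = new SparkContext("local[*]", "tuple-encoder-example");
    SQLContext sqlContext = new SQLContext(sc);
    Encoder$ e = Encoder$.MODULE$;

    // Encoder.tuple nests two flat encoders into a two-field struct (_1, _2); the
    // extract/construct expression rewriting in the patch maps those fields back to
    // scala.Tuple2 accessors.
    Encoder<Tuple2<Integer, String>> encoder = e.tuple(e.INT(), e.STRING());

    List<Tuple2<Integer, String>> data = Arrays.asList(
      new Tuple2<Integer, String>(1, "a"),
      new Tuple2<Integer, String>(2, "b"));
    Dataset<Tuple2<Integer, String>> ds = sqlContext.createDataset(data, encoder);

    // Round-trips through the tuple encoder and back to Tuple2 instances.
    System.out.println(ds.collectAsList());  // [(1,a), (2,b)]

    sc.stop();
  }
}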