From ec2b807212e568c9e98cd80746bcb61e02c7a98e Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 11 Nov 2015 10:52:23 -0800
Subject: [SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder

We need to support custom classes such as Java beans and to combine them
into tuples, which is very hard to do with the TypeTag-based approach. We
should keep only the composition-based way to create tuple encoders. This
PR also moves `Encoder` to `org.apache.spark.sql`. (A short, illustrative
Java usage sketch is appended after the patch.)

Author: Wenchen Fan

Closes #9567 from cloud-fan/java.
---
 .../main/scala/org/apache/spark/sql/Encoder.scala  | 131 +++++++++++++++
 .../spark/sql/catalyst/encoders/Encoder.scala      | 182 ---------------------
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  10 +-
 .../spark/sql/catalyst/encoders/package.scala      |   3 +-
 .../catalyst/plans/logical/basicOperators.scala    |   1 +
 .../main/scala/org/apache/spark/sql/Column.scala   |   2 +-
 .../scala/org/apache/spark/sql/DataFrame.scala     |   2 -
 .../org/apache/spark/sql/GroupedDataset.scala      |   2 +-
 .../scala/org/apache/spark/sql/SQLContext.scala    |   2 +-
 .../aggregate/TypedAggregateExpression.scala       |   3 +-
 .../apache/spark/sql/expressions/Aggregator.scala  |   3 +-
 .../scala/org/apache/spark/sql/functions.scala     |   2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java     |  78 ++++-----
 .../apache/spark/sql/DatasetAggregatorSuite.scala  |   4 +-
 .../scala/org/apache/spark/sql/QueryTest.scala     |   1 -
 15 files changed, 189 insertions(+), 237 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
new file mode 100644
index 0000000000..1ff7340557
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+
+/**
+ * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
+ *
+ * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
+ * and reuse internal buffers to improve performance.
+ */
+trait Encoder[T] extends Serializable {
+
+  /** Returns the schema of encoding this type of object as a Row. */
+  def schema: StructType
+
+  /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
+  def clsTag: ClassTag[T]
+}
+
+object Encoders {
+  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+  }
+
+  def tuple[T1, T2, T3](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+  }
+
+  def tuple[T1, T2, T3, T4](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+  }
+
+  def tuple[T1, T2, T3, T4, T5](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4],
+      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+  }
+
+  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    assert(encoders.length > 1)
+    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+    })
+
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }.zipWithIndex.map { case (expr, index) =>
+      expr.transformUp {
+        case BoundReference(0, t: ObjectType, _) =>
+          Invoke(
+            BoundReference(0, ObjectType(cls), nullable = true),
+            s"_${index + 1}",
+            t)
+      }
+    }
+
+    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+      if (enc.flat) {
+        enc.constructExpression.transform {
+          case b: BoundReference => b.copy(ordinal = index)
+        }
+      } else {
+        enc.constructExpression.transformUp {
+          case BoundReference(ordinal, dt, _) =>
+            GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
+        }
+      }
+    }
+
+    val constructExpression =
+      NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
deleted file mode 100644
index 6569b900fe..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.sql.catalyst.expressions._
-
-/**
- * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
- *
- * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
- * and reuse internal buffers to improve performance.
- */
-trait Encoder[T] extends Serializable {
-
-  /** Returns the schema of encoding this type of object as a Row. */
-  def schema: StructType
-
-  /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
-  def clsTag: ClassTag[T]
-}
-
-object Encoder {
-  import scala.reflect.runtime.universe._
-
-  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
-  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
-  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
-  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
-  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
-  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
-  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
-  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
-
-  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
-    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
-  }
-
-  def tuple[T1, T2, T3](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
-    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
-  }
-
-  def tuple[T1, T2, T3, T4](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
-    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
-  }
-
-  def tuple[T1, T2, T3, T4, T5](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4],
-      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
-    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
-  }
-
-  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
-    assert(encoders.length > 1)
-    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
-    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
-
-    val schema = StructType(encoders.zipWithIndex.map {
-      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
-    })
-
-    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
-    val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t: ObjectType, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
-      }
-    }
-
-    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
-      if (enc.flat) {
-        enc.constructExpression.transform {
-          case b: BoundReference => b.copy(ordinal = index)
-        }
-      } else {
-        enc.constructExpression.transformUp {
-          case BoundReference(ordinal, dt, _) =>
-            GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
-        }
-      }
-    }
-
-    val constructExpression =
-      NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
-
-    new ExpressionEncoder[Any](
-      schema,
-      false,
-      extractExpressions,
-      constructExpression,
-      ClassTag.apply(cls))
-  }
-
-  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
-  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
-    import scala.reflect.api
-
-    // val mirror = runtimeMirror(c.getClassLoader)
-    val mirror = rootMirror
-    val sym = mirror.staticClass(c.getName)
-    val tpe = sym.selfType
-    TypeTag(mirror, new api.TypeCreator {
-      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
-        if (m eq mirror) tpe.asInstanceOf[U # Type]
-        else throw new IllegalArgumentException(
-          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
-    })
-  }
-
-  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    ExpressionEncoder[(T1, T2)]()
-  }
-
-  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    ExpressionEncoder[(T1, T2, T3)]()
-  }
-
-  def forTuple[T1, T2, T3, T4](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    ExpressionEncoder[(T1, T2, T3, T4)]()
-  }
-
-  def forTuple[T1, T2, T3, T4, T5](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
-    : Encoder[(T1, T2, T3, T4, T5)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    implicit val typeTag5 = getTypeTag(c5)
-    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c0627f5..294afde534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,18 +17,18 @@ package org.apache.spark.sql.catalyst.encoders
 
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitves to and from the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index d4642a5006..2c35adca9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.Encoder
+
 package object encoders {
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
     case e: ExpressionEncoder[A] => e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aaebd..597f03e752 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d26b6c3579..f0f275e91f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 691b476fff..a492099b93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.commons.lang3.StringUtils
@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index db61499229..61e2a95450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1cf1e30f96..cd1fdc4edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index b5a87c56e6..dfcbac8687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 2aa5a7d540..360c9a5bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a59d738010..ab49ed4b5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 2da63d1b96..33d8388f61 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
 import org.apache.spark.sql.test.TestSQLContext;
@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
   private transient TestSQLContext context;
-  private transient Encoder$ e = Encoder$.MODULE$;
 
   @Before
   public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }
@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }
@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
    List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());
 
     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
 
     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@@ -111,7 +110,7 @@
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
 
     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
@@ -123,7 +122,7 @@
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
       flatMapped.collectAsList());
@@ -133,7 +132,7 @@
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     ds.foreach(new ForeachFunction<String>() {
       @Override
@@ -147,7 +146,7 @@
   @Test
   public void testReduce() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
@@ -161,13 +160,13 @@
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
       @Override
@@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
       }
       return sb.toString();
     }
-    }, e.STRING());
+    }, Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
 
@@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
         return Collections.singletonList(sb.toString());
       }
     },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
       new CoGroupFunction<Integer, String, Integer, String>() {
        @Override
        public Iterable<String> call(
-          Integer key,
-          Iterator<String> left,
-          Iterator<Integer> right) throws Exception {
+            Integer key,
+            Iterator<String> left,
+            Iterator<Integer> right) throws Exception {
          StringBuilder sb = new StringBuilder(key.toString());
          while (left.hasNext()) {
            sb.append(left.next());
@@ -225,7 +224,7 @@
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
   }
@@ -233,8 +232,9 @@
   @Test
   public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    GroupedDataset<Integer, String> grouped =
+      ds.groupBy(length(col("value"))).asKey(Encoders.INT());
 
     Dataset<String> mapped = grouped.map(
       new MapGroupFunction<Integer, String, String>() {
@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
           return sb.toString();
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
   }
@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
-      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));
 
     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
       selected.collectAsList());
  }
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     Assert.assertEquals(
       Arrays.asList("abc", "xyz"),
       sort(ds.distinct().collectAsList().toArray(new String[0])));
 
     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
 
     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
 
     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
 
   @Test
   public void testTupleEncoder() {
-    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
     Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());
 
-    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    Encoder<Tuple3<Integer, Long, String>> encoder3 =
+      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
      Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
     Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
 
     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
      Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
     Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());
 
     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+        Encoders.BOOLEAN());
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
      Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
@@ -342,7 +344,7 @@
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
-      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+      Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
      Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
     Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
@@ -350,7 +352,8 @@
 
     // test (int, (string, string, long))
     Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
-      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
      Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
@@ -359,7 +362,8 @@
 
     // test (int, ((string, long), string))
     Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
-      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
      Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index d4f0ab76cf..378cd36527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,13 +17,11 @@ package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._
 import scala.language.postfixOps
 
 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator
 
 /** An `Aggregator` that adds up any numeric type returned by the given function. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174efe73..7a8b7ae5bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
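
As referenced in the commit message above, here is a minimal sketch of how the
composition-based tuple encoders are meant to be used from Java after this
change. It is not part of the patch: the `run` wrapper and the sample data are
illustrative assumptions, and `context` stands in for an already-constructed
SQLContext such as the TestSQLContext wired up in JavaDatasetSuite.

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SQLContext;

public class TupleEncoderSketch {
  // Hypothetical entry point; `context` is assumed to exist already.
  public static void run(SQLContext context) {
    // Compose a tuple encoder from its component encoders; no TypeTags involved.
    Encoder<Tuple2<Integer, String>> encoder =
        Encoders.tuple(Encoders.INT(), Encoders.STRING());

    List<Tuple2<Integer, String>> data = Arrays.asList(
        new Tuple2<Integer, String>(1, "a"),
        new Tuple2<Integer, String>(2, "b"));
    Dataset<Tuple2<Integer, String>> ds = context.createDataset(data, encoder);

    // Component encoders nest, so composite types are built the same way.
    Encoder<Tuple2<Tuple2<Integer, String>, Long>> nested =
        Encoders.tuple(encoder, Encoders.LONG());
  }
}

Because Encoders.tuple takes already-built Encoder values rather than deriving
them from TypeTags, any encoder can slot in as a component, which is what makes
room for encoders over custom classes such as Java beans later on.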