about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-16 12:45:34 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-16 12:45:34 -0800
commitb1a9662623951079e80bd7498e064c4cae4977e9 (patch)
tree5c846f13b701d60bb068a1489981b6be07a4b14f /sql
parent24477d2705bcf2a851acc241deb8376c5450dc73 (diff)
downloadspark-b1a9662623951079e80bd7498e064c4cae4977e9.tar.gz
spark-b1a9662623951079e80bd7498e064c4cae4977e9.tar.bz2
spark-b1a9662623951079e80bd7498e064c4cae4977e9.zip
[SPARK-11754][SQL] consolidate `ExpressionEncoder.tuple` and `Encoders.tuple`
These two are very similar, so we can consolidate them into one. This change also adds tests for the consolidated method and fixes a bug. Author: Wenchen Fan <wenchen@databricks.com> Closes #9729 from cloud-fan/tuple.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala95
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala104
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala29
3 files changed, 108 insertions, 120 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 5f619d6c33..c8b017e251 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -19,10 +19,8 @@ package org.apache.spark.sql
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.types.StructType
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -49,83 +47,34 @@ object Encoders {
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
- def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
- tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
- .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+ def tuple[T1, T2](
+ e1: Encoder[T1],
+ e2: Encoder[T2]): Encoder[(T1, T2)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
}
def tuple[T1, T2, T3](
- enc1: Encoder[T1],
- enc2: Encoder[T2],
- enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
- tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
- .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
}
def tuple[T1, T2, T3, T4](
- enc1: Encoder[T1],
- enc2: Encoder[T2],
- enc3: Encoder[T3],
- enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
- tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
- .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
}
def tuple[T1, T2, T3, T4, T5](
- enc1: Encoder[T1],
- enc2: Encoder[T2],
- enc3: Encoder[T3],
- enc4: Encoder[T4],
- enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
- tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
- .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
- }
-
- private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
- assert(encoders.length > 1)
- // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
- assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
-
- val schema = StructType(encoders.zipWithIndex.map {
- case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
- })
-
- val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
- val extractExpressions = encoders.map {
- case e if e.flat => e.toRowExpressions.head
- case other => CreateStruct(other.toRowExpressions)
- }.zipWithIndex.map { case (expr, index) =>
- expr.transformUp {
- case BoundReference(0, t: ObjectType, _) =>
- Invoke(
- BoundReference(0, ObjectType(cls), nullable = true),
- s"_${index + 1}",
- t)
- }
- }
-
- val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
- if (enc.flat) {
- enc.fromRowExpression.transform {
- case b: BoundReference => b.copy(ordinal = index)
- }
- } else {
- enc.fromRowExpression.transformUp {
- case BoundReference(ordinal, dt, _) =>
- GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
- }
- }
- }
-
- val constructExpression =
- NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
-
- new ExpressionEncoder[Any](
- schema,
- flat = false,
- extractExpressions,
- constructExpression,
- ClassTag(cls))
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4],
+ e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+ ExpressionEncoder.tuple(
+ encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0d3e4aafb0..9a1a8f5cbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -67,47 +67,77 @@ object ExpressionEncoder {
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
encoders.foreach(_.assertUnresolved())
- val schema =
- StructType(
- encoders.zipWithIndex.map {
- case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
- })
+ val schema = StructType(encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
+
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- // Rebind the encoders to the nested schema.
- val newConstructExpressions = encoders.zipWithIndex.map {
- case (e, i) if !e.flat => e.nested(i).fromRowExpression
- case (e, i) => e.shift(i).fromRowExpression
+ val toRowExpressions = encoders.map {
+ case e if e.flat => e.toRowExpressions.head
+ case other => CreateStruct(other.toRowExpressions)
+ }.zipWithIndex.map { case (expr, index) =>
+ expr.transformUp {
+ case BoundReference(0, t, _) =>
+ Invoke(
+ BoundReference(0, ObjectType(cls), nullable = true),
+ s"_${index + 1}",
+ t)
+ }
}
- val constructExpression =
- NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
-
- val input = BoundReference(0, ObjectType(cls), false)
- val extractExpressions = encoders.zipWithIndex.map {
- case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
- case b: BoundReference =>
- Invoke(input, s"_${i + 1}", b.dataType, Nil)
- }))
- case (e, i) => e.toRowExpressions.head transformUp {
- case b: BoundReference =>
- Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ if (enc.flat) {
+ enc.fromRowExpression.transform {
+ case b: BoundReference => b.copy(ordinal = index)
+ }
+ } else {
+ val input = BoundReference(index, enc.schema, nullable = true)
+ enc.fromRowExpression.transformUp {
+ case UnresolvedAttribute(nameParts) =>
+ assert(nameParts.length == 1)
+ UnresolvedExtractValue(input, Literal(nameParts.head))
+ case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
+ }
}
}
+ val fromRowExpression =
+ NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
+
new ExpressionEncoder[Any](
schema,
- false,
- extractExpressions,
- constructExpression,
- ClassTag.apply(cls))
+ flat = false,
+ toRowExpressions,
+ fromRowExpression,
+ ClassTag(cls))
}
- /** A helper for producing encoders of Tuple2 from other encoders. */
def tuple[T1, T2](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
- tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+ tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+
+ def tuple[T1, T2, T3](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2],
+ e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
+ tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+
+ def tuple[T1, T2, T3, T4](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2],
+ e3: ExpressionEncoder[T3],
+ e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
+ tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+
+ def tuple[T1, T2, T3, T4, T5](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2],
+ e3: ExpressionEncoder[T3],
+ e4: ExpressionEncoder[T4],
+ e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
+ tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
}
/**
@@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
})
}
- /**
- * Returns a copy of this encoder where the expressions used to create an object given an
- * input row have been modified to pull the object out from a nested struct, instead of the
- * top level fields.
- */
- private def nested(i: Int): ExpressionEncoder[T] = {
- // We don't always know our input type at this point since it might be unresolved.
- // We fill in null and it will get unbound to the actual attribute at this position.
- val input = BoundReference(i, NullType, nullable = true)
- copy(fromRowExpression = fromRowExpression transformUp {
- case u: Attribute =>
- UnresolvedExtractValue(input, Literal(u.name))
- case b: BoundReference =>
- GetStructField(
- input,
- StructField(s"i[${b.ordinal}]", b.dataType),
- b.ordinal)
- })
- }
-
protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index fda978e705..bc539d62c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
productTest(("Seq[Seq[(Int, Int)]]",
Seq(Seq((1, 2)))))
+ encodeDecodeTest(
+ 1 -> 10L,
+ ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
+ "tuple with 2 flat encoders")
+
+ encodeDecodeTest(
+ (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
+ ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
+ "tuple with 2 product encoders")
+
+ encodeDecodeTest(
+ (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
+ ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
+ "tuple with flat encoder and product encoder")
+
+ encodeDecodeTest(
+ (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
+ ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
+ "tuple with product encoder and flat encoder")
+
+ encodeDecodeTest(
+ (1, (10, 100L)),
+ {
+ val intEnc = FlatEncoder[Int]
+ val longEnc = FlatEncoder[Long]
+ ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
+ },
+ "nested tuple encoder")
+
private def productTest[T <: Product : TypeTag](input: T): Unit = {
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
}