about summary refs log tree commit diff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
author    Wenchen Fan <wenchen@databricks.com>    2015-11-06 15:37:07 -0800
committer Reynold Xin <rxin@databricks.com>       2015-11-06 15:37:07 -0800
commit   7e9a9e603abce8689938bdd62d04b29299644aa4 (patch)
tree     cff62e9cde2e44aae8a2b5a8a2ae536bb67b9f38 /sql/catalyst/src
parent   f6680cdc5d2912dea9768ef5c3e2cc101b06daf8 (diff)
download spark-7e9a9e603abce8689938bdd62d04b29299644aa4.tar.gz
         spark-7e9a9e603abce8689938bdd62d04b29299644aa4.tar.bz2
         spark-7e9a9e603abce8689938bdd62d04b29299644aa4.zip
[SPARK-11269][SQL] Java API support & test cases for Dataset
This simply brings https://github.com/apache/spark/pull/9358 up-to-date.

Author: Wenchen Fan <wenchen@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #9528 from rxin/dataset-java.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala    | 123
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala |  21
2 files changed, 141 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 329a132d3d..f05e18288d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.encoders
-
-
import scala.reflect.ClassTag
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
/** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
}
+
+object Encoder {
+ import scala.reflect.runtime.universe._
+
+ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+ def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+ def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+ def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+ def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+ tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+ }
+
+ def tuple[T1, T2, T3](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+ tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+ }
+
+ def tuple[T1, T2, T3, T4](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3],
+ enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+ tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+ }
+
+ def tuple[T1, T2, T3, T4, T5](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3],
+ enc4: Encoder[T4],
+ enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+ tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+ }
+
+ private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ assert(encoders.length > 1)
+ // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+ assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+ val schema = StructType(encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
+
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+ val extractExpressions = encoders.map {
+ case e if e.flat => e.extractExpressions.head
+ case other => CreateStruct(other.extractExpressions)
+ }.zipWithIndex.map { case (expr, index) =>
+ expr.transformUp {
+ case BoundReference(0, t: ObjectType, _) =>
+ Invoke(
+ BoundReference(0, ObjectType(cls), true),
+ s"_${index + 1}",
+ t)
+ }
+ }
+
+ val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ if (enc.flat) {
+ enc.constructExpression.transform {
+ case b: BoundReference => b.copy(ordinal = index)
+ }
+ } else {
+ enc.constructExpression.transformUp {
+ case BoundReference(ordinal, dt, _) =>
+ GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+ }
+ }
+ }
+
+ val constructExpression =
+ NewInstance(cls, constructExpressions, false, ObjectType(cls))
+
+ new ExpressionEncoder[Any](
+ schema,
+ false,
+ extractExpressions,
+ constructExpression,
+ ClassTag.apply(cls))
+ }
+
+
+ def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
+
+ private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
+ import scala.reflect.api
+
+ // val mirror = runtimeMirror(c.getClassLoader)
+ val mirror = rootMirror
+ val sym = mirror.staticClass(c.getName)
+ val tpe = sym.selfType
+ TypeTag(mirror, new api.TypeCreator {
+ def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
+ if (m eq mirror) tpe.asInstanceOf[U # Type]
+ else throw new IllegalArgumentException(
+ s"Type tag defined in $mirror cannot be migrated to other mirrors.")
+ })
+ }
+
+ def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ ExpressionEncoder[(T1, T2)]()
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 8185528976..4f58464221 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
}
}
+
+case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
+ extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val row = child.gen(ctx)
+ s"""
+ ${row.code}
+ final boolean ${ev.isNull} = ${row.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
+ }
+ """
+ }
+}