diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-06 15:37:07 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-11-06 15:37:07 -0800 |
commit | 7e9a9e603abce8689938bdd62d04b29299644aa4 (patch) | |
tree | cff62e9cde2e44aae8a2b5a8a2ae536bb67b9f38 /sql/catalyst/src | |
parent | f6680cdc5d2912dea9768ef5c3e2cc101b06daf8 (diff) | |
download | spark-7e9a9e603abce8689938bdd62d04b29299644aa4.tar.gz spark-7e9a9e603abce8689938bdd62d04b29299644aa4.tar.bz2 spark-7e9a9e603abce8689938bdd62d04b29299644aa4.zip |
[SPARK-11269][SQL] Java API support & test cases for Dataset
This simply brings https://github.com/apache/spark/pull/9358 up-to-date.
Author: Wenchen Fan <wenchen@databricks.com>
Author: Reynold Xin <rxin@databricks.com>
Closes #9528 from rxin/dataset-java.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala | 123 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 21 |
2 files changed, 141 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 329a132d3d..f05e18288d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.encoders - - import scala.reflect.ClassTag -import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils +import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions._ /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable { /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } + +object Encoder { + import scala.reflect.runtime.universe._ + + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + + def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { + tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2)]] + } + + def tuple[T1, T2, T3]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + } + + def tuple[T1, T2, T3, T4]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + } + + def tuple[T1, T2, T3, T4, T5]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4], + enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + } + + private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + assert(encoders.length > 1) + // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. + assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t: ObjectType, _) => + Invoke( + BoundReference(0, ObjectType(cls), true), + s"_${index + 1}", + t) + } + } + + val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.constructExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + enc.constructExpression.transformUp { + case BoundReference(ordinal, dt, _) => + GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + } + } + } + + val constructExpression = + NewInstance(cls, constructExpressions, false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + + def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] + + private def getTypeTag[T](c: Class[T]): TypeTag[T] = { + import scala.reflect.api + + // val mirror = runtimeMirror(c.getClassLoader) + val mirror = rootMirror + val sym = mirror.staticClass(c.getName) + val tpe = sym.selfType + TypeTag(mirror, new api.TypeCreator { + def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) = + if (m eq mirror) tpe.asInstanceOf[U # Type] + else throw new IllegalArgumentException( + s"Type tag defined in $mirror cannot be migrated to other mirrors.") + }) + } + + def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + ExpressionEncoder[(T1, T2)]() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 8185528976..4f58464221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" } } + +case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) + extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val row = child.gen(ctx) + s""" + ${row.code} + final boolean ${ev.isNull} = ${row.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; + } + """ + } +} |