about | summary | refs | log | tree | commit | diff
path: root/sql/catalyst
diff options
context:
space:
mode:
author    Wenchen Fan <wenchen@databricks.com>    2015-11-24 11:09:01 -0800
committer Michael Armbrust <michael@databricks.com>    2015-11-24 11:09:01 -0800
commit    19530da6903fa59b051eec69b9c17e231c68454b (patch)
tree      d8dda4431af63527b7a08b9a3cc52cb21d94e17b /sql/catalyst
parent    52bc25c8e26d4be250d8ff7864067528f4f98592 (diff)
download  spark-19530da6903fa59b051eec69b9c17e231c68454b.tar.gz
          spark-19530da6903fa59b051eec69b9c17e231c68454b.tar.bz2
          spark-19530da6903fa59b051eec69b9c17e231c68454b.zip
[SPARK-11926][SQL] unify GetStructField and GetInternalRowField
Author: Wenchen Fan <wenchen@databricks.com> Closes #9909 from cloud-fan/get-struct.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala          |  8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala   |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala          |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala          | 21
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala |  4
9 files changed, 21 insertions, 42 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 476becec4d..d133ad3f0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {
/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
- .map(p => GetInternalRowField(p, ordinal, dataType))
+ .map(p => GetStructField(p, ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))
/** Returns the current path or `BoundReference`. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 6485bdfb30..1b2a8dc4c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
if (attribute.isDefined) {
// This target resolved to an attribute in child. It must be a struct. Expand it.
attribute.get.dataType match {
- case s: StructType => {
- s.fields.map( f => {
- val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
+ case s: StructType => s.zipWithIndex.map {
+ case (f, i) =>
+ val extract = GetStructField(attribute.get, i)
Alias(extract, target.get + "." + f.name)()
- })
}
+
case _ => {
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7bc9aed0b2..0c10a56c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -111,7 +111,7 @@ object ExpressionEncoder {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
- case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
+ case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index fa553e7c53..67518f52d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -220,7 +220,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(GetInternalRowField(input, i, f.dataType)))
+ constructorFor(GetStructField(input, i)))
}
CreateExternalRow(convertedFields)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 540ed35006..169435a10e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyString: String = {
transform {
- case a: AttributeReference => PrettyAttribute(a.name)
+ case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index f871b737ff..10ce10aaf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -51,7 +51,7 @@ object ExtractValue {
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
- GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+ GetStructField(child, ordinal, Some(fieldName))
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
@@ -97,18 +97,18 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
- * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
+ *
+ * Note that we can pass in the field name directly to keep case preserving in `toString`.
+ * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
-case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
extends UnaryExpression {
- override def dataType: DataType = child.dataType match {
- case s: StructType => s(ordinal).dataType
- // This is a hack to avoid breaking existing code until we remove the need for the struct field
- case _ => field.dataType
- }
+ private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+
+ override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
- override def toString: String = s"$child.${field.name}"
+ override def toString: String = s"$child.${name.getOrElse(field.name)}"
protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 00b7970bd1..26b6aca799 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -273,7 +273,8 @@ case class AttributeReference(
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
-case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
+case class PrettyAttribute(name: String, dataType: DataType = NullType)
+ extends Attribute with Unevaluable {
override def toString: String = name
@@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
override def nullable: Boolean = throw new UnsupportedOperationException
- override def dataType: DataType = NullType
}
object VirtualColumn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4a1f419f0a..62d09f0f55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
}
}
-case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
- extends UnaryExpression {
-
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, eval => {
- s"""
- if ($eval.isNullAt($ordinal)) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
- }
- """
- })
- }
-}
-
/**
* Serializes an input object using a generic serializer (Kryo or Java).
* @param kryo if true, use Kryo. Otherwise, use Java.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index e60990aeb4..62fd47234b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
def getStructField(expr: Expression, fieldName: String): GetStructField = {
expr.dataType match {
case StructType(fields) =>
- val field = fields.find(_.name == fieldName).get
- GetStructField(expr, field, fields.indexOf(field))
+ val index = fields.indexWhere(_.name == fieldName)
+ GetStructField(expr, index)
}
}