author    Wenchen Fan <cloud0fan@outlook.com>        2015-02-09 16:39:34 -0800
committer Michael Armbrust <michael@databricks.com>  2015-02-09 16:39:34 -0800
commit 0ee53ebce9944722e76b2b28fae79d9956be9f17 (patch)
tree   2607124e553ce958e1b60b460e949727532104c4 /sql/catalyst
parent 2a36292534a1e9f7a501e88f69bfc3a09fb62cb3 (diff)
[SPARK-2096][SQL] support dot notation on array of struct
~~The rule is simple: if you want `a.b` to work, then `a` must be some level of nested array of struct (level 0 means just a StructType), and the result of `a.b` is the same level of nested array of b-type. An optimization: the resolve chain looks like `Attribute -> GetItem -> GetField -> GetField ...`, so we could pass the nested-array information between `GetItem` and `GetField` to avoid repeated computation of the `innerDataType` and `containsNullList` of that nested array.~~

marmbrus Could you take a look?

To evaluate `a.b` when `a` is an array of struct, `a.b` means getting field `b` from each element of `a` and returning the results as an array.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #2405 from cloud-fan/nested-array-dot and squashes the following commits:

08a228a [Wenchen Fan] support dot notation on array of struct
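A minimal sketch of the new behavior, evaluated directly on the `ArrayGetField` expression this patch adds (the schema, data, and use of catalyst's `Literal`/`GenericRow` helpers are illustrative scaffolding, not part of the patch):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// `a` has type array<struct<b: int>>; after this patch, `a.b` resolves to an
// ArrayGetField that extracts `b` from each element.
val bField = StructField("b", IntegerType, nullable = true)
val aType  = ArrayType(StructType(Array(bField)), containsNull = true)

// a = [{b: 1}, {b: 2}], wrapped in a foldable Literal so eval needs no input row.
val a = Literal(Seq(new GenericRow(Array[Any](1)), new GenericRow(Array[Any](2))), aType)

// a.b => [1, 2]
assert(ArrayGetField(a, bField, ordinal = 0, containsNull = true).eval(null) == Seq(1, 2))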
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                    | 30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala             | 34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                  |  3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala |  2
4 files changed, 51 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b59ed1739..fb2ff014ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
* desired fields are found.
*/
protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
+ def findField(fields: Array[StructField]): Int = {
+ val checkField = (f: StructField) => resolver(f.name, fieldName)
+ val ordinal = fields.indexWhere(checkField)
+ if (ordinal == -1) {
+ sys.error(
+ s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+ } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+ sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ } else {
+ ordinal
+ }
+ }
expr.dataType match {
case StructType(fields) =>
- val actualField = fields.filter(f => resolver(f.name, fieldName))
- if (actualField.length == 0) {
- sys.error(
- s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
- } else if (actualField.length == 1) {
- val field = actualField(0)
- GetField(expr, field, fields.indexOf(field))
- } else {
- sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
- }
+ val ordinal = findField(fields)
+ StructGetField(expr, fields(ordinal), ordinal)
+ case ArrayType(StructType(fields), containsNull) =>
+ val ordinal = findField(fields)
+ ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
}
}
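The two-pass `indexWhere` in `findField` finds the first match and then probes for a second one starting just past it, so the common single-match case never builds a filtered array. The same logic in a standalone, hypothetical form (a case-insensitive resolver, as Spark SQL uses by default; the field names are illustrative):

val fieldNames = Array("a", "B", "b")
val checkField = (name: String) => name.equalsIgnoreCase("b")

val ordinal   = fieldNames.indexWhere(checkField)              // 1: first match wins
val duplicate = fieldNames.indexWhere(checkField, ordinal + 1) // 2: a second match exists
// ordinal == -1   => "No such struct field"
// duplicate != -1 => "Ambiguous reference to fields"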
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 66e2e5c4ba..68051a2a20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
}
}
+
+trait GetField extends UnaryExpression {
+ self: Product =>
+
+ type EvaluatedType = Any
+ override def foldable = child.foldable
+ override def toString = s"$child.${field.name}"
+
+ def field: StructField
+}
+
/**
* Returns the value of fields in the Struct `child`.
*/
-case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
- type EvaluatedType = Any
+case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
def dataType = field.dataType
override def nullable = child.nullable || field.nullable
- override def foldable = child.foldable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
+}
- override def toString = s"$child.${field.name}"
+/**
+ * Returns the array of values of the field in the Array of Struct `child`.
+ */
+case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+ extends GetField {
+
+ def dataType = ArrayType(field.dataType, containsNull)
+ override def nullable = child.nullable
+
+ override def eval(input: Row): Any = {
+ val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+ if (baseValue == null) null else {
+ baseValue.map { row =>
+ if (row == null) null else row(ordinal)
+ }
+ }
+ }
}
/**
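`ArrayGetField.eval` preserves nulls at two levels: a null array yields null, and a null element yields a null entry in the result, which is why the result's `ArrayType` carries `containsNull`. The mapping step in isolation, with plain Seqs standing in for Rows:

val rows: Seq[Seq[Any]] = Seq(Seq(1), null, Seq(3)) // stand-in for Seq[Row]
val ordinal = 0
val result  = rows.map(row => if (row == null) null else row(ordinal))
// result == Seq(1, null, 3): null elements stay null rather than throwing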
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fd58b9681e..0da081ed1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -209,7 +209,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
- case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+ case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+ case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
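Splitting the old `GetField` case in two keeps `NullPropagation` exhaustive: extracting a field from a literal-null struct, or mapping a field over a literal-null array, is always null, so either expression folds to a typed null literal. A hedged sketch checking that the folded form agrees with direct evaluation (the schema is illustrative):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val bField    = StructField("b", IntegerType, nullable = true)
val nullArray = Literal(null, ArrayType(StructType(Array(bField)), containsNull = true))
val e         = ArrayGetField(nullArray, bField, ordinal = 0, containsNull = true)

// NullPropagation rewrites `e` to Literal(null, e.dataType), i.e. a null of
// type array<int>; evaluating the original expression gives the same result.
assert(e.eval(null) == null)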
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 7cf6c80194..dcfd8b28cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
- GetField(expr, field, fields.indexOf(field))
+ StructGetField(expr, field, fields.indexOf(field))
}
}