aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala12
2 files changed, 51 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
new file mode 100644
index 0000000000..8364379644
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values
+ * to be looked up even when the attributes used differ cosmetically (i.e., the capitalization
+ * of the name, or the expected nullability).
+ */
+object AttributeMap {
+ def apply[A](kvs: Seq[(Attribute, A)]) =
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+}
+
+class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+ extends Map[Attribute, A] with Serializable {
+
+ override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+
+ override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
+ (baseMap.map(_._2) + kv).toMap
+
+ override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+
+ override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 54c6baf1af..fa80b07f8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
}
object BindReferences extends Logging {
- def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
+
+ def bindReference[A <: Expression](
+ expression: A,
+ input: Seq[Attribute],
+ allowFailures: Boolean = false): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
- sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+ if (allowFailures) {
+ a
+ } else {
+ sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+ }
} else {
BoundReference(ordinal, a.dataType, a.nullable)
}