aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala 58
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala 61
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala 32
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 2
5 files changed, 154 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 0bde48ce57..3f9227a8ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
@@ -301,6 +302,7 @@ object FunctionRegistry {
expression[UnBase64]("unbase64"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),
+ expression[XPathBoolean]("xpath_boolean"),
// datetime functions
expression[AddMonths]("add_months"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
new file mode 100644
index 0000000000..2a5256c7f5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+/** Evaluates the XPath expression `path` against the XML string `xml`, yielding a boolean. */
+@ExpressionDescription(
+  usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.",
+  extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue")
+case class XPathBoolean(xml: Expression, path: Expression)
+  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+
+  override def left: Expression = xml
+  override def right: Expression = path
+
+  override def prettyName: String = "xpath_boolean"
+
+  override def dataType: DataType = BooleanType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+  // Helper that performs the actual XPath evaluation; constructed on first use.
+  @transient private lazy val xpathUtil = new UDFXPathUtil
+
+  // A constant path is decoded from UTF8String to String only once instead of once per row.
+  // null means the path is not a string literal and has to be decoded on every evaluation.
+  @transient lazy val pathLiteral: String = path match {
+    case Literal(constant: UTF8String, _) => constant.toString
+    case _ => null
+  }
+
+  // Prefers the cached literal path; otherwise decodes the path value supplied for this row.
+  override protected def nullSafeEval(xml: Any, path: Any): Any = {
+    val document = xml.asInstanceOf[UTF8String].toString
+    val query = if (pathLiteral eq null) path.asInstanceOf[UTF8String].toString else pathLiteral
+    xpathUtil.evalBoolean(document, query)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
new file mode 100644
index 0000000000..f7c65c667e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.types.StringType
+
+/** Test suite for various xpath functions. */
+class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  // Builds XPathBoolean from two string literals and checks that it evaluates to `expected`.
+  private def testBoolean[T](xml: String, path: String, expected: T): Unit = {
+    checkEvaluation(
+      XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)),
+      expected)
+  }
+
+  test("xpath_boolean") {
+    testBoolean("<a><b>b</b></a>", "a/b", true)
+    testBoolean("<a><b>b</b></a>", "a/c", false)
+    testBoolean("<a><b>b</b></a>", "a/b = \"b\"", true)
+    testBoolean("<a><b>b</b></a>", "a/b = \"c\"", false)
+    testBoolean("<a><b>10</b></a>", "a/b < 10", false)
+    testBoolean("<a><b>10</b></a>", "a/b = 10", true)
+
+    // null input
+    testBoolean(null, null, null)
+    testBoolean(null, "a", null)
+    testBoolean("<a><b>10</b></a>", null, null)
+
+    // Invalid XML must throw. Evaluate the expression directly instead of via testBoolean:
+    // a checkEvaluation mismatch is reported by throwing, which intercept would also accept.
+    intercept[Exception] {
+      XPathBoolean(Literal("<a>/a>"), Literal("a")).eval()
+    }
+  }
+
+  test("xpath_boolean path cache invalidation") {
+    // Non-literal path: the same expression must honour each row's path, not reuse an old one.
+    val expr = XPathBoolean(Literal("<a><b>b</b></a>"), 'path.string.at(0))
+    checkEvaluation(expr, true, create_row("a/b"))
+    checkEvaluation(expr, false, create_row("a/c"))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
new file mode 100644
index 0000000000..532d48cc26
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+/** End-to-end checks of the XML expressions, driven through DataFrame.selectExpr. */
+class XmlFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("xpath_boolean") {
+    // Single row holding an XML document and the path to evaluate against it.
+    val input = Seq(("<a><b>b</b></a>", "a/b")).toDF("xml", "path")
+    val result = input.selectExpr("xpath_boolean(xml, path)")
+    checkAnswer(result, Row(true))
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 2589b9d4a0..fa560a044b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -241,7 +241,7 @@ private[sql] class HiveSessionCatalog(
"elt", "hash", "java_method", "histogram_numeric",
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
- "xpath", "xpath_boolean", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
+ "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",
// table generating function