aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorpetermaxlee <petermaxlee@gmail.com>2016-07-11 13:28:34 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-11 13:28:34 +0800
commit82f0874453991510216779926d795b0a4e07e854 (patch)
tree7606c863e631a1fb7822ce38348cca8f87855513 /sql/catalyst/src/main
parent52b5bb0b7fabe6cc949f514c548f9fbc6a4fa181 (diff)
downloadspark-82f0874453991510216779926d795b0a4e07e854.tar.gz
spark-82f0874453991510216779926d795b0a4e07e854.tar.bz2
spark-82f0874453991510216779926d795b0a4e07e854.zip
[SPARK-16318][SQL] Implement all remaining xpath functions
## What changes were proposed in this pull request? This patch implements all remaining xpath functions that Hive supports and not natively supported in Spark: xpath_int, xpath_short, xpath_long, xpath_float, xpath_double, xpath_string, and xpath. ## How was this patch tested? Added unit tests and end-to-end tests. Author: petermaxlee <petermaxlee@gmail.com> Closes #13991 from petermaxlee/SPARK-16318.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala58
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala174
4 files changed, 190 insertions, 67 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
index 410e9e51ba..d224332d8a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
@@ -43,7 +43,7 @@ public class UDFXPathUtil {
private XPathExpression expression = null;
private String oldPath = null;
- public Object eval(String xml, String path, QName qname) {
+ public Object eval(String xml, String path, QName qname) throws XPathExpressionException {
if (xml == null || path == null || qname == null) {
return null;
}
@@ -56,7 +56,7 @@ public class UDFXPathUtil {
try {
expression = xpath.compile(path);
} catch (XPathExpressionException e) {
- expression = null;
+ throw new RuntimeException("Invalid XPath '" + path + "'" + e.getMessage(), e);
}
oldPath = path;
}
@@ -66,31 +66,30 @@ public class UDFXPathUtil {
}
reader.set(xml);
-
try {
return expression.evaluate(inputSource, qname);
} catch (XPathExpressionException e) {
- throw new RuntimeException("Invalid expression '" + oldPath + "'", e);
+ throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e);
}
}
- public Boolean evalBoolean(String xml, String path) {
+ public Boolean evalBoolean(String xml, String path) throws XPathExpressionException {
return (Boolean) eval(xml, path, XPathConstants.BOOLEAN);
}
- public String evalString(String xml, String path) {
+ public String evalString(String xml, String path) throws XPathExpressionException {
return (String) eval(xml, path, XPathConstants.STRING);
}
- public Double evalNumber(String xml, String path) {
+ public Double evalNumber(String xml, String path) throws XPathExpressionException {
return (Double) eval(xml, path, XPathConstants.NUMBER);
}
- public Node evalNode(String xml, String path) {
+ public Node evalNode(String xml, String path) throws XPathExpressionException {
return (Node) eval(xml, path, XPathConstants.NODE);
}
- public NodeList evalNodeList(String xml, String path) {
+ public NodeList evalNodeList(String xml, String path) throws XPathExpressionException {
return (NodeList) eval(xml, path, XPathConstants.NODESET);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c8bbbf8853..54568b7445 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -310,7 +310,15 @@ object FunctionRegistry {
expression[UnBase64]("unbase64"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),
+ expression[XPathList]("xpath"),
expression[XPathBoolean]("xpath_boolean"),
+ expression[XPathDouble]("xpath_double"),
+ expression[XPathDouble]("xpath_number"),
+ expression[XPathFloat]("xpath_float"),
+ expression[XPathInt]("xpath_int"),
+ expression[XPathLong]("xpath_long"),
+ expression[XPathShort]("xpath_short"),
+ expression[XPathString]("xpath_string"),
// datetime functions
expression[AddMonths]("add_months"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
deleted file mode 100644
index 2a5256c7f5..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.xml
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
-
-
-@ExpressionDescription(
- usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.",
- extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue")
-case class XPathBoolean(xml: Expression, path: Expression)
- extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
-
- @transient private lazy val xpathUtil = new UDFXPathUtil
-
- // If the path is a constant, cache the path string so that we don't need to convert path
- // from UTF8String to String for every row.
- @transient lazy val pathLiteral: String = path match {
- case Literal(str: UTF8String, _) => str.toString
- case _ => null
- }
-
- override def prettyName: String = "xpath_boolean"
-
- override def dataType: DataType = BooleanType
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
-
- override def left: Expression = xml
- override def right: Expression = path
-
- override protected def nullSafeEval(xml: Any, path: Any): Any = {
- val xmlString = xml.asInstanceOf[UTF8String].toString
- if (pathLiteral ne null) {
- xpathUtil.evalBoolean(xmlString, pathLiteral)
- } else {
- xpathUtil.evalBoolean(xmlString, path.asInstanceOf[UTF8String].toString)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
new file mode 100644
index 0000000000..47f039e6a4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Base class for xpath_boolean, xpath_double, xpath_int, etc.
+ *
+ * This is not the world's most efficient implementation due to type conversion, but works.
+ */
+abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+ override def left: Expression = xml
+ override def right: Expression = path
+
+ /** XPath expressions are always nullable, e.g. if the xml string is empty. */
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!path.foldable) {
+ TypeCheckFailure("path should be a string literal")
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
+ @transient protected lazy val xpathUtil = new UDFXPathUtil
+ @transient protected lazy val pathString: String = path.eval().asInstanceOf[UTF8String].toString
+
+ /** Concrete implementations need to override the following three methods. */
+ def xml: Expression
+ def path: Expression
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.",
+ extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue")
+case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract {
+
+ override def prettyName: String = "xpath_boolean"
+ override def dataType: DataType = BooleanType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString)
+ }
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns a short value that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3")
+case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_int"
+ override def dataType: DataType = ShortType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (ret eq null) null else ret.shortValue()
+ }
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns an integer value that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3")
+case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_int"
+ override def dataType: DataType = IntegerType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (ret eq null) null else ret.intValue()
+ }
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns a long value that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3")
+case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_long"
+ override def dataType: DataType = LongType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (ret eq null) null else ret.longValue()
+ }
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns a float value that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3.0")
+case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_float"
+ override def dataType: DataType = FloatType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (ret eq null) null else ret.floatValue()
+ }
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns a double value that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3.0")
+case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_float"
+ override def dataType: DataType = DoubleType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (ret eq null) null else ret.doubleValue()
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns the text contents of the first xml node that matches the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>b</b><c>cc</c></a>','a/c');\ncc")
+// scalastyle:on line.size.limit
+case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath_string"
+ override def dataType: DataType = StringType
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
+ UTF8String.fromString(ret)
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Returns a string array of values within xml nodes that match the xpath expression",
+ extended = "> SELECT _FUNC_('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()');\n['b1','b2','b3']")
+// scalastyle:on line.size.limit
+case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
+ override def prettyName: String = "xpath"
+ override def dataType: DataType = ArrayType(StringType, containsNull = false)
+
+ override def nullSafeEval(xml: Any, path: Any): Any = {
+ val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
+ if (nodeList ne null) {
+ val ret = new Array[UTF8String](nodeList.getLength)
+ var i = 0
+ while (i < nodeList.getLength) {
+ ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue)
+ i += 1
+ }
+ new GenericArrayData(ret)
+ } else {
+ null
+ }
+ }
+}