author     wujian <jan.chou.wu@gmail.com>        2016-07-08 14:38:05 -0700
committer  Reynold Xin <rxin@databricks.com>     2016-07-08 14:38:05 -0700
commit     f5fef69143b2a83bb8b168b7417e92659af0c72c (patch)
tree       322c0af0ab3388c4d68656e6dd675d41799b04be /sql/catalyst
parent     142df4834bc33dc7b84b626c6ee3508ab1abe015 (diff)
[SPARK-16281][SQL] Implement parse_url SQL function
## What changes were proposed in this pull request?
This PR adds the parse_url SQL function so that the Hive fallback for it can be removed. It is a new implementation of #13999.

## How was this patch tested?
The existing tests pass, including the new test cases added here.

Author: wujian <jan.chou.wu@gmail.com>

Closes #14008 from janplus/SPARK-16281.
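For readers who want to try the new function end to end, here is a minimal sketch (not part of this patch; `ParseUrlExample` and the local SparkSession setup are illustrative) that invokes parse_url through spark.sql once the expression is registered. The URLs mirror the examples in the expression description below.

```scala
import org.apache.spark.sql.SparkSession

object ParseUrlExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("parse_url example")
      .getOrCreate()

    // Two-argument form: extract a named part of the URL.
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST')").show()
    // -> spark.apache.org

    // Three-argument form: extract the value of a single query key.
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query')").show()
    // -> 1

    spark.stop()
  }
}
```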
Diffstat (limited to 'sql/catalyst')
-rw-r--r--   sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala        1
-rw-r--r--   sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala    150
-rw-r--r--   sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala 51
3 files changed, 202 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 842c9c63ce..c8bbbf8853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -288,6 +288,7 @@ object FunctionRegistry {
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
+ expression[ParseUrl]("parse_url"),
expression[FormatString]("printf"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 894e12d4a3..61549c9a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import java.net.{MalformedURLException, URL}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
+import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer
@@ -654,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
override def prettyName: String = "rpad"
}
+object ParseUrl {
+ private val HOST = UTF8String.fromString("HOST")
+ private val PATH = UTF8String.fromString("PATH")
+ private val QUERY = UTF8String.fromString("QUERY")
+ private val REF = UTF8String.fromString("REF")
+ private val PROTOCOL = UTF8String.fromString("PROTOCOL")
+ private val FILE = UTF8String.fromString("FILE")
+ private val AUTHORITY = UTF8String.fromString("AUTHORITY")
+ private val USERINFO = UTF8String.fromString("USERINFO")
+ private val REGEXPREFIX = "(&|^)"
+ private val REGEXSUBFIX = "=([^&]*)"
+}
+
+/**
+ * Extracts a part from a URL
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
+ extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO.
+ Key specifies which query to extract.
+ Examples:
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST')
+ 'spark.apache.org'
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY')
+ 'query=1'
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query')
+ '1'""")
+case class ParseUrl(children: Seq[Expression])
+ extends Expression with ExpectsInputTypes with CodegenFallback {
+
+ override def nullable: Boolean = true
+ override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+ override def dataType: DataType = StringType
+ override def prettyName: String = "parse_url"
+
+ // If the url is a constant, cache the URL object so that we don't need to convert url
+ // from UTF8String to String to URL for every row.
+ @transient private lazy val cachedUrl = children(0) match {
+ case Literal(url: UTF8String, _) if url ne null => getUrl(url)
+ case _ => null
+ }
+
+ // If the key is a constant, cache the Pattern object so that we don't need to convert key
+ // from UTF8String to String to StringBuilder to String to Pattern for every row.
+ @transient private lazy val cachedPattern = children(2) match {
+ case Literal(key: UTF8String, _) if key ne null => getPattern(key)
+ case _ => null
+ }
+
+ // If the partToExtract is a constant, cache the Extract part function so that we don't need
+ // to check the partToExtract for every row.
+ @transient private lazy val cachedExtractPartFunc = children(1) match {
+ case Literal(part: UTF8String, _) => getExtractPartFunc(part)
+ case _ => null
+ }
+
+ import ParseUrl._
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size > 3 || children.size < 2) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
+ } else {
+ super[ExpectsInputTypes].checkInputDataTypes()
+ }
+ }
+
+ private def getPattern(key: UTF8String): Pattern = {
+ Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
+ }
+
+ private def getUrl(url: UTF8String): URL = {
+ try {
+ new URL(url.toString)
+ } catch {
+ case e: MalformedURLException => null
+ }
+ }
+
+ private def getExtractPartFunc(partToExtract: UTF8String): URL => String = {
+ partToExtract match {
+ case HOST => _.getHost
+ case PATH => _.getPath
+ case QUERY => _.getQuery
+ case REF => _.getRef
+ case PROTOCOL => _.getProtocol
+ case FILE => _.getFile
+ case AUTHORITY => _.getAuthority
+ case USERINFO => _.getUserInfo
+ case _ => (url: URL) => null
+ }
+ }
+
+ private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
+ val m = pattern.matcher(query.toString)
+ if (m.find()) {
+ UTF8String.fromString(m.group(2))
+ } else {
+ null
+ }
+ }
+
+ private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = {
+ if (cachedExtractPartFunc ne null) {
+ UTF8String.fromString(cachedExtractPartFunc.apply(url))
+ } else {
+ UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
+ }
+ }
+
+ private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
+ if (cachedUrl ne null) {
+ extractFromUrl(cachedUrl, partToExtract)
+ } else {
+ val currentUrl = getUrl(url)
+ if (currentUrl ne null) {
+ extractFromUrl(currentUrl, partToExtract)
+ } else {
+ null
+ }
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
+ if (evaluated.contains(null)) return null
+ if (evaluated.size == 2) {
+ parseUrlWithoutKey(evaluated(0), evaluated(1))
+ } else {
+ // 3-arg, i.e. QUERY with key
+ assert(evaluated.size == 3)
+ if (evaluated(1) != QUERY) {
+ return null
+ }
+
+ val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
+ if (query eq null) {
+ return null
+ }
+
+ if (cachedPattern ne null) {
+ extractValueFromQuery(query, cachedPattern)
+ } else {
+ extractValueFromQuery(query, getPattern(evaluated(2)))
+ }
+ }
+ }
+}
+
/**
* Returns the input formatted according to printf-style format strings
*/
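As an aside, the part extraction in the patch delegates entirely to the accessors of java.net.URL. The following standalone sketch (an illustration only; `UrlPartSketch` and `extractPart` are hypothetical names, not code from this change) shows the same mapping that getExtractPartFunc implements:

```scala
import java.net.URL

object UrlPartSketch {
  // Map a part name to the corresponding java.net.URL accessor, as getExtractPartFunc does.
  def extractPart(urlStr: String, part: String): Option[String] = {
    val url = new URL(urlStr)
    val value = part match {
      case "HOST"      => url.getHost
      case "PATH"      => url.getPath
      case "QUERY"     => url.getQuery
      case "REF"       => url.getRef
      case "PROTOCOL"  => url.getProtocol
      case "FILE"      => url.getFile
      case "AUTHORITY" => url.getAuthority
      case "USERINFO"  => url.getUserInfo
      case _           => null
    }
    // The accessors return null for absent parts, which maps to SQL NULL in the expression.
    Option(value)
  }

  def main(args: Array[String]): Unit = {
    println(extractPart("http://spark.apache.org/path?query=1", "HOST"))     // Some(spark.apache.org)
    println(extractPart("http://spark.apache.org/path?query=1", "USERINFO")) // None
  }
}
```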
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 256ce85743..8f7b1041fa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -726,6 +726,57 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}
+ test("ParseUrl") {
+ def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
+ checkEvaluation(
+ ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
+ }
+ def checkParseUrlWithKey(
+ expected: String,
+ urlStr: String,
+ partToExtract: String,
+ key: String): Unit = {
+ checkEvaluation(
+ ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
+ }
+
+ checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
+ checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH")
+ checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY")
+ checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
+ checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
+ checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
+ checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
+ checkParseUrl("userinfo", "http://userinfo@spark.apache.org/path?query=1", "USERINFO")
+ checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")
+
+ // Null checking
+ checkParseUrl(null, null, "HOST")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", null)
+ checkParseUrl(null, null, null)
+ checkParseUrl(null, "test", "HOST")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", "USERINFO")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "")
+
+ // exceptional cases
+ intercept[java.util.regex.PatternSyntaxException] {
+ evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
+ Literal("QUERY"), Literal("???"))))
+ }
+
+ // arguments checking
+ assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
+ .checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
+ }
+
test("Sentences") {
val nullString = Literal.create(null, StringType)
checkEvaluation(Sentences(nullString, nullString, nullString), null)
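The "exceptional cases" block in the test above follows directly from how the key argument is compiled into a regular expression. A minimal sketch (illustration only; `QueryKeySketch` and `extractKey` are hypothetical names) of that extraction, and of why a key such as '???' raises PatternSyntaxException:

```scala
import java.util.regex.Pattern

object QueryKeySketch {
  // The key is wrapped into "(&|^)<key>=([^&]*)" and capture group 2 holds the value,
  // mirroring getPattern/extractValueFromQuery in the patch.
  def extractKey(query: String, key: String): Option[String] = {
    val m = Pattern.compile("(&|^)" + key + "=([^&]*)").matcher(query)
    if (m.find()) Some(m.group(2)) else None
  }

  def main(args: Array[String]): Unit = {
    println(extractKey("query=1&page=2", "page"))    // Some(2)
    println(extractKey("query=1&page=2", "missing")) // None

    // A key that is not a valid regex fails at Pattern.compile, which is what the test intercepts.
    try {
      extractKey("query=1", "???")
    } catch {
      case e: java.util.regex.PatternSyntaxException =>
        println("invalid key regex: " + e.getDescription)
    }
  }
}
```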