diff options
author | wujian <jan.chou.wu@gmail.com> | 2016-07-08 14:38:05 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-07-08 14:38:05 -0700 |
commit | f5fef69143b2a83bb8b168b7417e92659af0c72c (patch) | |
tree | 322c0af0ab3388c4d68656e6dd675d41799b04be /sql/catalyst/src/main/scala | |
parent | 142df4834bc33dc7b84b626c6ee3508ab1abe015 (diff) | |
download | spark-f5fef69143b2a83bb8b168b7417e92659af0c72c.tar.gz spark-f5fef69143b2a83bb8b168b7417e92659af0c72c.tar.bz2 spark-f5fef69143b2a83bb8b168b7417e92659af0c72c.zip |
[SPARK-16281][SQL] Implement parse_url SQL function
## What changes were proposed in this pull request?
This PR adds parse_url SQL functions in order to remove Hive fallback.
A new implementation of #13999
## How was this patch tested?
Pass the exist tests including new testcases.
Author: wujian <jan.chou.wu@gmail.com>
Closes #14008 from janplus/SPARK-16281.
Diffstat (limited to 'sql/catalyst/src/main/scala')
2 files changed, 151 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 842c9c63ce..c8bbbf8853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -288,6 +288,7 @@ object FunctionRegistry { expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), expression[JsonTuple]("json_tuple"), + expression[ParseUrl]("parse_url"), expression[FormatString]("printf"), expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 894e12d4a3..61549c9a23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.net.{MalformedURLException, URL} import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} +import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer @@ -654,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) override def prettyName: String = "rpad" } +object ParseUrl { + private val HOST = UTF8String.fromString("HOST") + private val PATH = UTF8String.fromString("PATH") + private val QUERY = UTF8String.fromString("QUERY") + private val REF = UTF8String.fromString("REF") + private val PROTOCOL = UTF8String.fromString("PROTOCOL") + private val FILE = UTF8String.fromString("FILE") + private val AUTHORITY = UTF8String.fromString("AUTHORITY") + private val USERINFO = UTF8String.fromString("USERINFO") + private val REGEXPREFIX = "(&|^)" + private val REGEXSUBFIX = "=([^&]*)" +} + +/** + * Extracts a part from a URL + */ +@ExpressionDescription( + usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL", + extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO. + Key specifies which query to extract. + Examples: + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST') + 'spark.apache.org' + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY') + 'query=1' + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query') + '1'""") +case class ParseUrl(children: Seq[Expression]) + extends Expression with ExpectsInputTypes with CodegenFallback { + + override def nullable: Boolean = true + override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + override def prettyName: String = "parse_url" + + // If the url is a constant, cache the URL object so that we don't need to convert url + // from UTF8String to String to URL for every row. + @transient private lazy val cachedUrl = children(0) match { + case Literal(url: UTF8String, _) if url ne null => getUrl(url) + case _ => null + } + + // If the key is a constant, cache the Pattern object so that we don't need to convert key + // from UTF8String to String to StringBuilder to String to Pattern for every row. + @transient private lazy val cachedPattern = children(2) match { + case Literal(key: UTF8String, _) if key ne null => getPattern(key) + case _ => null + } + + // If the partToExtract is a constant, cache the Extract part function so that we don't need + // to check the partToExtract for every row. + @transient private lazy val cachedExtractPartFunc = children(1) match { + case Literal(part: UTF8String, _) => getExtractPartFunc(part) + case _ => null + } + + import ParseUrl._ + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size > 3 || children.size < 2) { + TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments") + } else { + super[ExpectsInputTypes].checkInputDataTypes() + } + } + + private def getPattern(key: UTF8String): Pattern = { + Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) + } + + private def getUrl(url: UTF8String): URL = { + try { + new URL(url.toString) + } catch { + case e: MalformedURLException => null + } + } + + private def getExtractPartFunc(partToExtract: UTF8String): URL => String = { + partToExtract match { + case HOST => _.getHost + case PATH => _.getPath + case QUERY => _.getQuery + case REF => _.getRef + case PROTOCOL => _.getProtocol + case FILE => _.getFile + case AUTHORITY => _.getAuthority + case USERINFO => _.getUserInfo + case _ => (url: URL) => null + } + } + + private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { + val m = pattern.matcher(query.toString) + if (m.find()) { + UTF8String.fromString(m.group(2)) + } else { + null + } + } + + private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = { + if (cachedExtractPartFunc ne null) { + UTF8String.fromString(cachedExtractPartFunc.apply(url)) + } else { + UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url)) + } + } + + private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { + if (cachedUrl ne null) { + extractFromUrl(cachedUrl, partToExtract) + } else { + val currentUrl = getUrl(url) + if (currentUrl ne null) { + extractFromUrl(currentUrl, partToExtract) + } else { + null + } + } + } + + override def eval(input: InternalRow): Any = { + val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]} + if (evaluated.contains(null)) return null + if (evaluated.size == 2) { + parseUrlWithoutKey(evaluated(0), evaluated(1)) + } else { + // 3-arg, i.e. QUERY with key + assert(evaluated.size == 3) + if (evaluated(1) != QUERY) { + return null + } + + val query = parseUrlWithoutKey(evaluated(0), evaluated(1)) + if (query eq null) { + return null + } + + if (cachedPattern ne null) { + extractValueFromQuery(query, cachedPattern) + } else { + extractValueFromQuery(query, getPattern(evaluated(2))) + } + } + } +} + /** * Returns the input formatted according do printf-style format strings */ |