1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
/**
 * Rule that resolves table-valued function references.
 *
 * An [[UnresolvedTableValuedFunction]] is handled once all of its arguments are resolved:
 * 1. Its name is looked up (lower-cased with `Locale.ROOT`) in a built-in registry of overloads.
 * 2. The arguments are implicitly cast to each overload's declared types.
 * 3. An overload whose casted arguments are all foldable is evaluated to build the replacement
 *    plan.
 * If no overload fits, analysis fails with a message listing the available alternatives.
 */
object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
  /**
   * List of argument names and their types, used to declare a function.
   */
  private case class ArgumentList(args: (String, DataType)*) {
    /**
     * Try to cast the expressions to satisfy the expected types of this argument list.
     *
     * @param values the argument expressions supplied at the call site
     * @return the cast expressions, or None when the number of values differs from the number of
     *         declared arguments or when any value cannot be cast to its expected type
     */
    def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
      if (args.length != values.length) {
        None
      } else {
        val casted = values.zip(args).map { case (value, (_, expectedType)) =>
          TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
        }
        // All-or-nothing: a single failed cast rejects this argument list.
        if (casted.forall(_.isDefined)) Some(casted.flatten) else None
      }
    }

    /** Renders the signature as `name: type, name: type, ...` for error messages. */
    override def toString: String = {
      args.map { case (name, dataType) => s"$name: ${dataType.typeName}" }.mkString(", ")
    }
  }

  /**
   * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
   * here allows for function overloading.
   */
  private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]

  /**
   * TVF builder.
   *
   * @param args the declared argument names and types of this overload
   * @param pf   builds the logical plan from the evaluated argument values
   * @return the overload's signature paired with a total resolver; values not matched by `pf`
   *         (e.g. a null value) raise an IllegalArgumentException
   */
  private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
      : (ArgumentList, Seq[Any] => LogicalPlan) = {
    (ArgumentList(args: _*),
      pf orElse {
        // Named `arguments` so it does not shadow the declared signature `args`.
        case arguments =>
          throw new IllegalArgumentException(
            "Invalid arguments for resolved function: " + arguments.mkString(", "))
      })
  }

  /**
   * Internal registry of table-valued functions.
   */
  private val builtinFunctions: Map[String, TVF] = Map(
    "range" -> Map(
      /* range(end) */
      tvf("end" -> LongType) { case Seq(end: Long) =>
        Range(0, end, 1, None)
      },

      /* range(start, end) */
      tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
        Range(start, end, 1, None)
      },

      /* range(start, end, step) */
      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
        case Seq(start: Long, end: Long, step: Long) =>
          Range(start, end, step, None)
      },

      /* range(start, end, step, numPartitions) */
      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))
      })
  )

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
        case Some(overloads) =>
          val resolved = overloads.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              // The arguments are evaluated via `eval()` with no input row, which is only
              // meaningful for constant-foldable expressions. Anything else is treated as a
              // non-matching overload, so the user gets the analysis error below rather than an
              // internal failure raised from inside eval().
              case Some(casted) if casted.forall(_.foldable) =>
                Some(resolver(casted.map(_.eval())))
              case _ =>
                None
            }
          }
          resolved.headOption.getOrElse {
            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
            u.failAnalysis(
              s"""error: table-valued function ${u.functionName} with alternatives:
                |${overloads.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
                |cannot be applied to: (${argTypes})""".stripMargin)
          }
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
      }
  }
}
|