1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.StructType
/**
 * :: Experimental ::
 * Implements the transformations which are defined by SQL statement.
 * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...'
 * where '__THIS__' represents the underlying table of the input dataset.
 * The select clause specifies the fields, constants, and expressions to display in
 * the output, it can be any select clause that Spark SQL supports. Users can also
 * use Spark SQL built-in functions and UDFs to operate on these selected columns.
 * For example, [[SQLTransformer]] supports statements like:
 *  - SELECT a, a + b AS a_b FROM __THIS__
 *  - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
 *  - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
 *
 * Every occurrence of the literal text `__THIS__` in the statement is textually replaced
 * with a freshly generated temporary table name before the statement is executed.
 */
@Experimental
@Since("1.6.0")
class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer
  with DefaultParamsWritable {

  @Since("1.6.0")
  def this() = this(Identifiable.randomUID("sql"))

  /**
   * SQL statement parameter. The statement is provided in string form.
   * @group param
   */
  @Since("1.6.0")
  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")

  /** @group setParam */
  @Since("1.6.0")
  def setStatement(value: String): this.type = set(statement, value)

  /** @group getParam */
  @Since("1.6.0")
  def getStatement: String = $(statement)

  /** Placeholder that users write in [[statement]] to refer to the input table. */
  private val tableIdentifier: String = "__THIS__"

  /** Returns [[statement]] with every placeholder occurrence replaced by `tableName`. */
  private def resolveStatement(tableName: String): String =
    $(statement).replace(tableIdentifier, tableName)

  /**
   * Registers `dataset` as a temporary table under a unique name and runs [[statement]]
   * against it.
   *
   * @param dataset input data referenced as `__THIS__` in the statement
   * @return the result of the SQL statement
   */
  @Since("1.6.0")
  override def transform(dataset: DataFrame): DataFrame = {
    val tableName = Identifiable.randomUID(uid)
    dataset.registerTempTable(tableName)
    // NOTE(review): this temp table is not dropped. In this Spark line SQLContext.dropTempTable
    // also tries to uncache the table's plan, which could unpersist a caller's cached input.
    dataset.sqlContext.sql(resolveStatement(tableName))
  }

  /**
   * Derives the output schema by running [[statement]] against an empty-row dummy
   * DataFrame that carries `schema`.
   *
   * @param schema schema of the input table referenced as `__THIS__`
   * @return schema of the statement's result
   */
  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    val sc = SparkContext.getOrCreate()
    val sqlContext = SQLContext.getOrCreate(sc)
    val dummyRDD = sc.parallelize(Seq(Row.empty))
    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
    // Use a unique name (as transform does) instead of the literal "__THIS__", so a user table
    // with that name is not clobbered and concurrent calls do not share one catalog entry.
    val tableName = Identifiable.randomUID(uid)
    dummyDF.registerTempTable(tableName)
    try {
      sqlContext.sql(resolveStatement(tableName)).schema
    } finally {
      // The dummy frame is never cached, so dropping it only removes the catalog entry.
      sqlContext.dropTempTable(tableName)
    }
  }

  @Since("1.6.0")
  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}
/** Companion object that makes [[SQLTransformer]] readable via `DefaultParamsReadable`. */
@Since("1.6.0")
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {

  /**
   * Loads a [[SQLTransformer]] from `path` by delegating to `DefaultParamsReadable.load`.
   *
   * @param path location to read the saved transformer from
   * @return the transformer returned by the parent implementation
   */
  @Since("1.6.0")
  override def load(path: String): SQLTransformer = {
    val restored: SQLTransformer = super.load(path)
    restored
  }
}
|