diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-08-11 11:01:59 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-08-11 11:01:59 -0700 |
commit | 8cad854ef6a2066de5adffcca6b79a205ccfd5f3 (patch) | |
tree | 4e221a3cdb0f9727002d7c01ea349a44a45cef73 /mllib | |
parent | bce72797f3499f14455722600b0d0898d4fd87c9 (diff) | |
download | spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.tar.gz spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.tar.bz2 spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.zip |
[SPARK-8345] [ML] Add an SQL node as a feature transformer
Implements transformations that are defined by an SQL statement.
Currently, we only support SQL statements of the form 'SELECT ... FROM __THIS__',
where '__THIS__' represents the underlying table of the input dataset.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #7465 from yanboliang/spark-8345 and squashes the following commits:
b403fcb [Yanbo Liang] address comments
0d4bb15 [Yanbo Liang] a better transformSchema() implementation
51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 72 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala | 44 |
2 files changed, 116 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 0000000000..95e4305638 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. 
+ * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 0000000000..d19052881a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} |