about summary refs log tree commit diff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-08-11 11:01:59 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-08-11 11:01:59 -0700
commit 8cad854ef6a2066de5adffcca6b79a205ccfd5f3 (patch)
tree 4e221a3cdb0f9727002d7c01ea349a44a45cef73 /mllib/src/test
parent bce72797f3499f14455722600b0d0898d4fd87c9 (diff)
download spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.tar.gz
spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.tar.bz2
spark-8cad854ef6a2066de5adffcca6b79a205ccfd5f3.zip
[SPARK-8345] [ML] Add an SQL node as a feature transformer
Implements the transforms which are defined by SQL statement.
Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
where '__THIS__' represents the underlying table of the input dataset.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7465 from yanboliang/spark-8345 and squashes the following commits:

b403fcb [Yanbo Liang] address comments
0d4bb15 [Yanbo Liang] a better transformSchema() implementation
51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala 44
1 files changed, 44 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
new file mode 100644
index 0000000000..d19052881a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/** Checks [[SQLTransformer]] params and a numeric `SELECT ... FROM __THIS__` transform. */
+class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new SQLTransformer())
+  }
+
+  test("transform numeric data") {
+    val inputRows = Seq((0, 1.0, 3.0), (2, 2.0, 5.0))
+    val expectedRows = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
+    val inputDF = sqlContext.createDataFrame(inputRows).toDF("id", "v1", "v2")
+    val transformer = new SQLTransformer() // __THIS__ stands for the input dataset's table
+      .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+    val actualDF = transformer.transform(inputDF)
+    val predictedSchema = transformer.transformSchema(inputDF.schema) // schema in, schema out
+    val expectedDF = sqlContext.createDataFrame(expectedRows).toDF("id", "v1", "v2", "v3", "v4")
+    assert(actualDF.schema.toString == predictedSchema.toString)
+    assert(predictedSchema == expectedDF.schema)
+    assert(actualDF.collect().toSeq == expectedDF.collect().toSeq)
+  }
+}