author     Xusen Yin <yinxusen@gmail.com>    2015-10-02 10:19:18 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-10-02 10:19:18 -0700
commit     23a9448c04da7130d6c41c37f9fdf03184422dc8 (patch)
tree       278ae3ef765cfaad0e87628ef99f0384339bbd87 /mllib/src/test/scala/org
parent     2a717821bbb026d4d8c43d31b3300721357951c6 (diff)
[SPARK-5890] [ML] Add feature discretizer
JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-5890).

I borrowed the code of `findSplits` from `RandomForest`; I don't think it's good to call it from `RandomForest` directly.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5779 from yinxusen/SPARK-5890.
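For context, a minimal usage sketch of the new transformer, mirroring how the test suite below drives it. It assumes an existing `SparkContext` named `sc` (e.g. from spark-shell); the data, column names, and bucket count are illustrative only:

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SQLContext

// Minimal sketch, assuming a pre-existing SparkContext `sc`.
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._

val df = sc.parallelize(Array(1.0, 2.0, 3.0, 3.0, 3.0).map(Tuple1.apply)).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)

// fit() estimates quantile-based splits from "input";
// transform() maps each value to its bucket index in "result".
val bucketed = discretizer.fit(df).transform(df)
bucketed.show()
```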
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 98
1 file changed, 98 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
new file mode 100644
index 0000000000..b2bdd8935f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkContext, SparkFunSuite}
+
+class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
+
+ test("Test quantile discretizer") {
+ checkDiscretizedData(sc,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ 10,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
+
+ checkDiscretizedData(sc,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ 4,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
+
+ checkDiscretizedData(sc,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ 3,
+ Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
+ Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+
+ checkDiscretizedData(sc,
+ Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
+ 2,
+ Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
+ Array("-Infinity, 2.0", "2.0, Infinity"))
+
+ }
+
+ test("Test getting splits") {
+ val splitTestPoints = Array(
+ Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
+ Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
+ Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
+ Array(Double.NegativeInfinity, Double.PositiveInfinity)
+ -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
+ Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
+ Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
+ Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
+ )
+ for ((ori, res) <- splitTestPoints) {
+ assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+ }
+ }
+}
+
+private object QuantileDiscretizerSuite extends SparkFunSuite {
+
+ def checkDiscretizedData(
+ sc: SparkContext,
+ data: Array[Double],
+ numBucket: Int,
+ expectedResult: Array[Double],
+ expectedAttrs: Array[String]): Unit = {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
+
+ val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
+ val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
+ .setNumBuckets(numBucket)
+ val result = discretizer.fit(df).transform(df)
+
+ val transformedFeatures = result.select("result").collect()
+ .map { case Row(transformedFeature: Double) => transformedFeature }
+ val transformedAttrs = Attribute.fromStructField(result.schema("result"))
+ .asInstanceOf[NominalAttribute].values.get
+
+ assert(transformedFeatures === expectedResult,
+ "Transformed features do not equal expected features.")
+ assert(transformedAttrs === expectedAttrs,
+ "Transformed attributes do not equal expected attributes.")
+ }
+}
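The `Test getting splits` cases above pin down the padding behavior the suite expects from `QuantileDiscretizer.getSplits`. A minimal sketch of that behavior, inferred from the expected outputs rather than taken from the actual implementation:

```scala
// Sketch only: reproduces the input/output pairs in `splitTestPoints` above,
// not the real QuantileDiscretizer.getSplits code.
def getSplitsSketch(candidates: Array[Double]): Array[Double] = {
  // Infinite candidates are dropped; the outer bounds are re-added below.
  val finite = candidates.filter(x => !x.isInfinite)
  // If no finite split remains, fall back to a single interior split at 0,
  // so the downstream bucketing always sees at least two buckets.
  val interior = if (finite.isEmpty) Array(0.0) else finite
  (Double.NegativeInfinity +: interior) :+ Double.PositiveInfinity
}

// e.g. getSplitsSketch(Array(0.0, 1.0)) == Array(-Infinity, 0.0, 1.0, Infinity)
```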