aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
blob: cd10c78311e1c22fdaee727ce4e4b993a6440b86 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.util.Random

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Tests for [[Bucketizer]]: param checks, bucketing of in-range features (with and without
 * infinite outer splits), rejection of out-of-range features, agreement of the binary search
 * with a brute-force reference, and default read/write round-tripping.
 */
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("params") {
    ParamsSuite.checkParams(new Bucketizer)
  }

  test("Bucket continuous features, without -inf,inf") {
    // Check a set of valid feature values.
    val splits = Array(-0.5, 0.0, 0.5)
    val validData = Array(-0.5, -0.3, 0.0, 0.2)
    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
    val dataFrame: DataFrame =
      spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

    val bucketizer: Bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setSplits(splits)

    assertExpectedBuckets(bucketizer, dataFrame)

    // Features outside the finite split range are expected to make the transform fail once the
    // job actually runs (hence the collect()).
    val invalidData1: Array[Double] = Array(-0.9) ++ validData
    val invalidData2 = Array(0.51) ++ validData
    val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
      intercept[SparkException] {
        bucketizer.transform(badDF1).collect()
      }
    }
    val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
      intercept[SparkException] {
        bucketizer.transform(badDF2).collect()
      }
    }
  }

  test("Bucket continuous features, with -inf,inf") {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val dataFrame: DataFrame =
      spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

    val bucketizer: Bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setSplits(splits)

    assertExpectedBuckets(bucketizer, dataFrame)
  }

  test("Binary search correctness on hand-picked examples") {
    import BucketizerSuite.checkBinarySearch
    // length 3, with -inf
    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
    // length 4
    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
    // length 5
    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
    // length 3, with inf
    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
    // length 3, with -inf and inf
    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
    // length 4, with -inf and inf
    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity))
  }

  test("Binary search correctness in contrast with linear search, on random data") {
    // A fixed seed keeps this test deterministic, so any mismatch can be reproduced.
    val rng = new Random(42L)
    val data = Array.fill(100)(rng.nextDouble())
    // Infinite outer splits guarantee every data point falls into some bucket.
    val splits: Array[Double] = Double.NegativeInfinity +:
      Array.fill(10)(rng.nextDouble()).sorted :+ Double.PositiveInfinity
    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
    assert(bsResult ~== lsResult absTol 1e-5)
  }

  test("read/write") {
    val t = new Bucketizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setSplits(Array(0.1, 0.8, 0.9))
    testDefaultReadWrite(t)
  }

  /**
   * Runs `bucketizer` over `dataFrame` and asserts that, for every row, the "result" column
   * produced by the transform equals the "expected" column supplied by the test.
   */
  private def assertExpectedBuckets(bucketizer: Bucketizer, dataFrame: DataFrame): Unit = {
    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y,
          s"The feature value is not correct after bucketing.  Expected $y but found $x")
    }
  }
}

private object BucketizerSuite extends SparkFunSuite {

  /**
   * Brute-force reference for bucket lookup, used to cross-check the binary search.
   * Bucket i covers the half-open range [splits(i), splits(i + 1)).
   *
   * @param splits  split points; presumably sorted ascending — NOTE(review): not checked here
   * @param feature value to place; must satisfy `feature >= splits.head`
   * @return index of the bucket containing `feature`, widened to Double
   * @throws IllegalArgumentException if `feature < splits.head` (also when `feature` is NaN)
   * @throws RuntimeException if no split after the first is strictly greater than `feature`
   */
  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
    require(feature >= splits.head)
    // Scan from index 1 for the first split strictly above the feature; the bucket is the
    // interval that ends at that split.
    val upperEdge = splits.indexWhere(split => feature < split, 1)
    if (upperEdge < 0) {
      throw new RuntimeException(
        s"linearSearchForBuckets failed to find bucket for feature value $feature")
    }
    (upperEdge - 1).toDouble
  }

  /**
   * Asserts that [[Bucketizer.binarySearchForBuckets]] places every split except the last, and
   * the midpoint of every adjacent pair of splits, into the bucket that starts at that split.
   */
  def checkBinarySearch(splits: Array[Double]): Unit = {
    def expectBucket(feature: Double, expectedBucket: Double): Unit = {
      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
        s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
          s" ${splits.mkString(", ")}")
    }
    for (bucket <- 0 until splits.length - 1) {
      val lower = splits(bucket)
      val upper = splits(bucket + 1)
      // The lower edge belongs to its own bucket.
      expectBucket(lower, bucket)
      // So does the midpoint; this still holds when the upper edge is +inf.
      expectBucket((lower + upper) / 2, bucket)
    }
  }
}