about summary refs log tree commit diff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 6
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 27
2 files changed, 31 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 9f3d2ca6db..28cbe1cb01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -186,8 +186,10 @@ class MinMaxScalerModel private[ml] (
val size = values.length
var i = 0
while (i < size) {
- val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
- values(i) = raw * scale + $(min)
+ if (!values(i).isNaN) {
+ val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
+ values(i) = raw * scale + $(min)
+ }
i += 1
}
Vectors.dense(values)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 5da8471175..9f376b7003 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -90,4 +90,31 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(newInstance.originalMin === instance.originalMin)
assert(newInstance.originalMax === instance.originalMax)
}
+
+ test("MinMaxScaler should remain NaN value") {
+ val data = Array(
+ Vectors.dense(1, Double.NaN, 2.0, 2.0),
+ Vectors.dense(2, 2.0, 0.0, 3.0),
+ Vectors.dense(3, Double.NaN, 0.0, 1.0),
+ Vectors.dense(6, 2.0, 2.0, Double.NaN))
+
+ val expected: Array[Vector] = Array(
+ Vectors.dense(-5.0, Double.NaN, 5.0, 0.0),
+ Vectors.dense(-3.0, 0.0, -5.0, 5.0),
+ Vectors.dense(-1.0, Double.NaN, -5.0, -5.0),
+ Vectors.dense(5.0, 0.0, 5.0, Double.NaN))
+
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val scaler = new MinMaxScaler()
+ .setInputCol("features")
+ .setOutputCol("scaled")
+ .setMin(-5)
+ .setMax(5)
+
+ val model = scaler.fit(df)
+ model.transform(df).select("expected", "scaled").collect()
+ .foreach { case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1.equals(vector2), "Transformed vector is different with expected.")
+ }
+ }
}