about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 9
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 29
2 files changed, 35 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index e9df161c00..fa5013d3c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
/**
* :: Experimental ::
@@ -70,7 +70,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
val inputColName = $(inputCol)
val outputColName = $(outputCol)
- SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+ require(schema(inputColName).dataType.isInstanceOf[NumericType],
+ s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
val inputFields = schema.fields
require(!inputFields.exists(_.name == outputColName),
s"Output column $outputColName already exists.")
@@ -133,7 +134,9 @@ class OneHotEncoder(override val uid: String) extends Transformer
val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
.aggregate(0.0)(
(m, x) => {
- assert(x >=0.0 && x == x.toInt,
+ assert(x <= Int.MaxValue,
+ s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
+ assert(x >= 0.0 && x == x.toInt,
s"Values from column $inputColName must be indices, but got $x.")
math.max(m, x)
},
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index e238b33ed8..49803aef71 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
class OneHotEncoderSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -111,4 +112,32 @@ class OneHotEncoderSuite
.setDropLast(false)
testDefaultReadWrite(t)
}
+
+ test("OneHotEncoder with varying types") { // each listed input column type must yield the same encoding
+ val df = stringIndexed() // helper defined outside this hunk; presumably yields "id" and "labelIndex" columns
+ val dfWithTypes = df // add one copy of "labelIndex" per type under test
+ .withColumn("shortLabel", df("labelIndex").cast(ShortType))
+ .withColumn("longLabel", df("labelIndex").cast(LongType))
+ .withColumn("intLabel", df("labelIndex").cast(IntegerType))
+ .withColumn("floatLabel", df("labelIndex").cast(FloatType))
+ .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
+ val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
+ "floatLabel", "decimalLabel")
+ for (col <- cols) { // NOTE(review): `col` shadows the imported sql.functions.col inside this loop
+ val encoder = new OneHotEncoder()
+ .setInputCol(col)
+ .setOutputCol("labelVec")
+ .setDropLast(false) // presumably keeps a slot for every category so vec(0..2) below all exist; confirm
+ val encoded = encoder.transform(dfWithTypes)
+
+ val output = encoded.select("id", "labelVec").rdd.map { r =>
+ val vec = r.getAs[Vector](1)
+ (r.getInt(0), vec(0), vec(1), vec(2)) // (id, first three components of the encoded vector)
+ }.collect().toSet
+ // label-to-index mapping assumed by the expected rows: a -> 0, b -> 2, c -> 1
+ val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
+ (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
+ assert(output === expected) // the same expected set is checked for every column in `cols`
+ }
+ }
}