aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala24
1 files changed, 19 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76021ad8f4..334410c962 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.util
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
/**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
}
/**
- * Check whether the given schema contains a column of one of the require data types.
- * @param colName column name
- * @param dataTypes required column data types
- */
+ * Check whether the given schema contains a column of one of the required data types.
+ * @param colName column name
+ * @param dataTypes required column data types
+ */
def checkColumnTypes(
schema: StructType,
colName: String,
@@ -61,6 +61,20 @@ private[spark] object SchemaUtils {
}
/**
+ * Check whether the given schema contains a column of a numeric data type.
+ * @param colName column name
+ * @param msg optional extra text appended to the failure message; ignored when null or blank
+ def checkNumericType(schema: StructType, colName: String, msg: String = ""): Unit = {
+   // Look up the declared data type of the column under inspection.
+   val found = schema(colName).dataType
+   // A null or whitespace-only `msg` contributes nothing; otherwise it follows a single space.
+   val suffix = Option(msg).filter(_.trim.nonEmpty).map(" " + _).getOrElse("")
+   require(
+     found.isInstanceOf[NumericType],
+     s"Column $colName must be of type NumericType but was actually of type $found.$suffix")
+ }
+
+ /**
* Appends a new column to the input schema. This fails if the given output column already exists.
* @param schema input schema
* @param colName new column name. If this column name is an empty string "", this method returns