diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76021ad8f4..334410c962 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} /** @@ -44,10 +44,10 @@ private[spark] object SchemaUtils { } /** - * Check whether the given schema contains a column of one of the require data types. - * @param colName column name - * @param dataTypes required column data types - */ + * Check whether the given schema contains a column of one of the require data types. + * @param colName column name + * @param dataTypes required column data types + */ def checkColumnTypes( schema: StructType, colName: String, @@ -61,6 +61,20 @@ private[spark] object SchemaUtils { } /** + * Check whether the given schema contains a column of the numeric data type. + * @param colName column name + */ + def checkNumericType( + schema: StructType, + colName: String, + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + + s"NumericType but was actually of type $actualDataType.$message") + } + + /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema * @param colName new column name. If this column name is an empty string "", this method returns |