author Nick Pentreath <nickp@za.ibm.com> 2016-05-18 21:13:12 +0200
committer Nick Pentreath <nickp@za.ibm.com> 2016-05-18 21:13:12 +0200
commit e8b79afa024123f9d4ceaf0a1043a7e37d913a8d
tree 112148429c15f0eed3ef53e6253898521352a011 /python
parent 3d1e67f903ab3512fcad82b94b1825578f8117c9
[SPARK-14891][ML] Add schema validation for ALS
This PR adds schema validation to `ml`'s ALS and ALSModel. Previously, no schema validation was performed, because `transformSchema` was never called in `ALS.fit` or `ALSModel.transform`. As a result, if users passed in Long (or Float, etc.) ids, they were silently cast to Int with no warning or error.

With this PR, ALS supports all numeric types for the `user`, `item`, and `rating` columns. The rating column is cast to `Float`, and the user and item columns are cast to `Int` (as before); however, for user/item, the cast throws an error if the value is outside the integer range. Behavior for the rating column is unchanged (as it is not an issue).

## How was this patch tested?

New test cases in `ALSSuite`.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12762 from MLnick/SPARK-14891-als-validate-schema.
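To illustrate the difference between a silent cast and the range-checked cast described above, here is a minimal pure-Python sketch. It is not the code from this patch (the actual validation lives in the Scala `ALS` implementation); the helper names `silent_int_cast` and `checked_int_cast` are hypothetical, and the wrap-around behavior shown for the silent cast is an assumption modeled on JVM Long-to-Int narrowing.

```python
# Bounds of a 32-bit signed integer (the JVM Int range).
INT_MIN = -2**31
INT_MAX = 2**31 - 1


def silent_int_cast(value):
    """Hypothetical helper: mimic JVM Long-to-Int narrowing.

    Keeps only the low 32 bits, so out-of-range ids wrap around
    without any warning or error.
    """
    wrapped = int(value) & 0xFFFFFFFF
    return wrapped - 2**32 if wrapped >= 2**31 else wrapped


def checked_int_cast(value):
    """Hypothetical helper: cast a numeric id to int, but raise
    if the value lies outside the 32-bit signed integer range."""
    if value > INT_MAX or value < INT_MIN:
        raise ValueError(
            "id %r is out of integer range [%d, %d]" % (value, INT_MIN, INT_MAX)
        )
    return int(value)


if __name__ == "__main__":
    big_id = 2**31  # one past the largest valid Int id
    print(silent_int_cast(big_id))  # wraps to a different (negative) id
    try:
        checked_int_cast(big_id)
    except ValueError as err:
        print("rejected:", err)
```

With the silent variant, two distinct large ids can collapse onto unrelated Int ids, which is why failing loudly is the safer choice for user and item columns.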
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/recommendation.py 8
1 files changed, 4 insertions, 4 deletions
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index d7cb658465..86c00d9165 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                           typeConverter=TypeConverters.toBoolean)
     alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
                   typeConverter=TypeConverters.toFloat)
-    userCol = Param(Params._dummy(), "userCol", "column name for user ids",
-                    typeConverter=TypeConverters.toString)
-    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids",
-                    typeConverter=TypeConverters.toString)
+    userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
+    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
     ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
                       typeConverter=TypeConverters.toString)
     nonnegative = Param(Params._dummy(), "nonnegative",