aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Zheng RuiFeng <ruifengz@foxmail.com> 2016-09-30 08:18:48 -0700
committer Yanbo Liang <ybliang8@gmail.com> 2016-09-30 08:18:48 -0700
commit 8e491af52930886cbe0c54e7d67add3796ddb15f (patch)
tree c72d7a99273d0eca03dd251f00780c105e569134
parent 1fad5596885aab8b32d2307c0edecbae50d5bd7a (diff)
download spark-8e491af52930886cbe0c54e7d67add3796ddb15f.tar.gz
spark-8e491af52930886cbe0c54e7d67add3796ddb15f.tar.bz2
spark-8e491af52930886cbe0c54e7d67add3796ddb15f.zip
[SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0
## What changes were proposed in this pull request?

Revert the change to NB Model's load to maintain compatibility with models stored before 2.0.

## How was this patch tested?

Local build.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #15313 from zhengruifeng/revert_save_load.
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 11
1 files changed, 7 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 0d652aa4c6..6775745167 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
-import org.apache.spark.sql.Dataset
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
@@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
- val pi = data.getAs[Vector](0)
- val theta = data.getAs[Matrix](1)
+ val data = sparkSession.read.parquet(dataPath)
+ val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
+ val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+ .select("pi", "theta")
+ .head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)
DefaultParamsReader.getAndSetParams(model, metadata)