aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorzlpmichelle <zlpmichelle@gmail.com>2016-06-30 00:50:14 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-06-30 00:50:14 -0700
commitb30a2dc7c50bfb70bd2b57be70530a9a9fa94a7a (patch)
tree8e27970fe282aa20feb6154e07177ed54421968c /mllib/src
parent2c3d96134dcc0428983eea087db7e91072215aea (diff)
downloadspark-b30a2dc7c50bfb70bd2b57be70530a9a9fa94a7a.tar.gz
spark-b30a2dc7c50bfb70bd2b57be70530a9a9fa94a7a.tar.bz2
spark-b30a2dc7c50bfb70bd2b57be70530a9a9fa94a7a.zip
[SPARK-16241][ML] model loading backward compatibility for ml NaiveBayes
## What changes were proposed in this pull request? Add model loading backward compatibility for ml NaiveBayes. ## How was this patch tested? Existing unit tests, plus a manual test of loading models saved by Spark 1.6. Author: zlpmichelle <zlpmichelle@gmail.com> Closes #13940 from zlpmichelle/naivebayes.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala11
1 files changed, 7 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 7c340312df..c99ae30155 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -28,8 +28,9 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
/**
* Params for Naive Bayes Classifiers.
@@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
- val pi = data.getAs[Vector](0)
- val theta = data.getAs[Matrix](1)
+ val data = sparkSession.read.parquet(dataPath)
+ val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
+ val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+ .select("pi", "theta")
+ .head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)
DefaultParamsReader.getAndSetParams(model, metadata)