aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorVinod K C <vinod.kc@huawei.com>2015-09-08 14:44:05 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-08 14:44:05 -0700
commite6f8d3686016a305a747c5bcc85f46fd4c0cbe83 (patch)
treef51524e1617c96df7660dfcf39abbc1e316104ff /mllib
parent7a9dcbc91d55dbc0cbf4812319bde65f4509b467 (diff)
downloadspark-e6f8d3686016a305a747c5bcc85f46fd4c0cbe83.tar.gz
spark-e6f8d3686016a305a747c5bcc85f46fd4c0cbe83.tar.bz2
spark-e6f8d3686016a305a747c5bcc85f46fd4c0cbe83.zip
[SPARK-10468] [ MLLIB ] Verify schema before Dataframe select API call
Loader.checkSchema was called to verify the schema after dataframe.select(...). Schema verification should be done before dataframe.select(...) Author: Vinod K C <vinod.kc@huawei.com> Closes #8636 from vinodkc/fix_GaussianMixtureModel_load_verification.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala4
2 files changed, 2 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 7f6163e04b..a5902190d4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -168,10 +168,9 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
val dataPath = Loader.dataPath(path)
val sqlContext = new SQLContext(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
- val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
-
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
+ val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
val (weights, gaussians) = dataArray.map {
case Row(weight: Double, mu: Vector, sigma: Matrix) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 36b124c5d2..58857c338f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -590,12 +590,10 @@ object Word2VecModel extends Loader[Word2VecModel] {
val dataPath = Loader.dataPath(path)
val sqlContext = new SQLContext(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
-
- val dataArray = dataFrame.select("word", "vector").collect()
-
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
+ val dataArray = dataFrame.select("word", "vector").collect()
val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
new Word2VecModel(word2VecMap)
}