aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core')
-rw-r--r--src/main/scala/xyz/driver/core/app.scala11
-rw-r--r--src/main/scala/xyz/driver/core/auth.scala1
-rw-r--r--src/main/scala/xyz/driver/core/file.scala155
-rw-r--r--src/main/scala/xyz/driver/core/file/FileSystemStorage.scala57
-rw-r--r--src/main/scala/xyz/driver/core/file/GcsStorage.scala82
-rw-r--r--src/main/scala/xyz/driver/core/file/S3Storage.scala66
-rw-r--r--src/main/scala/xyz/driver/core/file/package.scala60
-rw-r--r--src/main/scala/xyz/driver/core/rest.scala200
8 files changed, 423 insertions, 209 deletions
diff --git a/src/main/scala/xyz/driver/core/app.scala b/src/main/scala/xyz/driver/core/app.scala
index f7731e3..e35c300 100644
--- a/src/main/scala/xyz/driver/core/app.scala
+++ b/src/main/scala/xyz/driver/core/app.scala
@@ -26,6 +26,8 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
import scala.compat.Platform.ConcurrentModificationException
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}
+import scalaz.Scalaz.stringInstance
+import scalaz.syntax.equal._
object app {
@@ -59,6 +61,9 @@ object app {
}
}
+ private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
+ request.headers.find(_.name().toLowerCase === headerName).map(_.value())
+
protected def bindHttp(modules: Seq[Module]): Unit = {
val serviceTypes = modules.flatMap(_.routeTypes)
val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, serviceTypes, config)
@@ -74,6 +79,12 @@ object app {
val trackingId = rest.extractTrackingId(ctx.request)
MDC.put("trackingId", trackingId)
MDC.put("origin", origin)
+ MDC.put("xForwardedFor",
+ extractHeader(ctx.request)("x-forwarded-for")
+ .orElse(extractHeader(ctx.request)("x_forwarded_for"))
+ .getOrElse("unknown"))
+ MDC.put("remoteAddress", extractHeader(ctx.request)("remote-address").getOrElse("unknown"))
+ MDC.put("userAgent", extractHeader(ctx.request)("user-agent").getOrElse("unknown"))
MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown"))
def requestLogging: Future[Unit] = Future {
diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala
index f9a1a57..5dea2db 100644
--- a/src/main/scala/xyz/driver/core/auth.scala
+++ b/src/main/scala/xyz/driver/core/auth.scala
@@ -23,6 +23,7 @@ object auth {
final case class AuthToken(value: String)
final case class RefreshToken(value: String)
+ final case class PermissionsToken(value: String)
final case class PasswordHash(value: String)
diff --git a/src/main/scala/xyz/driver/core/file.scala b/src/main/scala/xyz/driver/core/file.scala
deleted file mode 100644
index dcc4b87..0000000
--- a/src/main/scala/xyz/driver/core/file.scala
+++ /dev/null
@@ -1,155 +0,0 @@
-package xyz.driver.core
-
-import java.io.File
-import java.nio.file.{Path, Paths}
-import java.util.UUID._
-
-import com.amazonaws.services.s3.AmazonS3
-import com.amazonaws.services.s3.model.{Bucket, GetObjectRequest, ListObjectsV2Request}
-import xyz.driver.core.time.Time
-
-import scala.concurrent.{ExecutionContext, Future}
-import scalaz.{ListT, OptionT}
-
-object file {
-
- final case class FileLink(
- name: Name[File],
- location: Path,
- revision: Revision[File],
- lastModificationDate: Time,
- fileSize: Long
- )
-
- trait FileService {
-
- def getFileLink(id: Name[File]): FileLink
-
- def getFile(fileLink: FileLink): File
- }
-
- trait FileStorage {
-
- def upload(localSource: File, destination: Path): Future[Unit]
-
- def download(filePath: Path): OptionT[Future, File]
-
- def delete(filePath: Path): Future[Unit]
-
- def list(path: Path): ListT[Future, FileLink]
-
- /** List of characters to avoid in S3 (I would say file names in general)
- *
- * @see http://stackoverflow.com/questions/7116450/what-are-valid-s3-key-names-that-can-be-accessed-via-the-s3-rest-api
- */
- private val illegalChars = "\\^`><{}][#%~|&@:,$=+?; "
-
- protected def checkSafeFileName[T](filePath: Path)(f: => T): T = {
- filePath.toString.find(c => illegalChars.contains(c)) match {
- case Some(illegalCharacter) =>
- throw new IllegalArgumentException(s"File name cannot contain character `$illegalCharacter`")
- case None => f
- }
- }
- }
-
- class S3Storage(s3: AmazonS3, bucket: Name[Bucket], executionContext: ExecutionContext) extends FileStorage {
- implicit private val execution = executionContext
-
- def upload(localSource: File, destination: Path): Future[Unit] = Future {
- checkSafeFileName(destination) {
- val _ = s3.putObject(bucket.value, destination.toString, localSource).getETag
- }
- }
-
- def download(filePath: Path): OptionT[Future, File] =
- OptionT.optionT(Future {
- val tempDir = System.getProperty("java.io.tmpdir")
- val randomFolderName = randomUUID().toString
- val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString)
-
- if (!tempDestinationFile.getParentFile.mkdirs()) {
- throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`")
- } else {
- Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString), tempDestinationFile)).map { _ =>
- tempDestinationFile
- }
- }
- })
-
- def delete(filePath: Path): Future[Unit] = Future {
- s3.deleteObject(bucket.value, filePath.toString)
- }
-
- def list(path: Path): ListT[Future, FileLink] =
- ListT.listT(Future {
- import scala.collection.JavaConverters._
- val req = new ListObjectsV2Request().withBucketName(bucket.value).withPrefix(path.toString).withMaxKeys(2)
-
- def isInSubFolder(path: Path)(fileLink: FileLink) =
- fileLink.location.toString.replace(path.toString + "/", "").contains("/")
-
- Iterator.continually(s3.listObjectsV2(req)).takeWhile { result =>
- req.setContinuationToken(result.getNextContinuationToken)
- result.isTruncated
- } flatMap { result =>
- result.getObjectSummaries.asScala.toList.map { summary =>
- FileLink(
- Name[File](summary.getKey),
- Paths.get(path.toString + "/" + summary.getKey),
- Revision[File](summary.getETag),
- Time(summary.getLastModified.getTime),
- summary.getSize
- )
- } filterNot isInSubFolder(path)
- } toList
- })
- }
-
- class FileSystemStorage(executionContext: ExecutionContext) extends FileStorage {
- implicit private val execution = executionContext
-
- def upload(localSource: File, destination: Path): Future[Unit] = Future {
- checkSafeFileName(destination) {
- val destinationFile = destination.toFile
-
- if (destinationFile.getParentFile.exists() || destinationFile.getParentFile.mkdirs()) {
- if (localSource.renameTo(destinationFile)) ()
- else {
- throw new Exception(
- s"Failed to move file from `${localSource.getCanonicalPath}` to `${destinationFile.getCanonicalPath}`")
- }
- } else {
- throw new Exception(s"Failed to create parent directories for file `${destinationFile.getCanonicalPath}`")
- }
- }
- }
-
- def download(filePath: Path): OptionT[Future, File] =
- OptionT.optionT(Future {
- Option(new File(filePath.toString)).filter(file => file.exists() && file.isFile)
- })
-
- def delete(filePath: Path): Future[Unit] = Future {
- val file = new File(filePath.toString)
- if (file.delete()) ()
- else {
- throw new Exception(s"Failed to delete file $file" + (if (!file.exists()) ", file does not exist." else "."))
- }
- }
-
- def list(path: Path): ListT[Future, FileLink] =
- ListT.listT(Future {
- val file = new File(path.toString)
- if (file.isDirectory) {
- file.listFiles().toList.filter(_.isFile).map { file =>
- FileLink(Name[File](file.getName),
- Paths.get(file.getPath),
- Revision[File](file.hashCode.toString),
- Time(file.lastModified()),
- file.length())
- }
- } else List.empty[FileLink]
- })
- }
-}
diff --git a/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala b/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala
new file mode 100644
index 0000000..bfe6995
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala
@@ -0,0 +1,57 @@
+package xyz.driver.core.file
+
+import java.io.File
+import java.nio.file.{Path, Paths}
+
+import xyz.driver.core.{Name, Revision}
+import xyz.driver.core.time.Time
+
+import scala.concurrent.{ExecutionContext, Future}
+import scalaz.{ListT, OptionT}
+
+class FileSystemStorage(executionContext: ExecutionContext) extends FileStorage {
+ implicit private val execution = executionContext
+
+ def upload(localSource: File, destination: Path): Future[Unit] = Future {
+ checkSafeFileName(destination) {
+ val destinationFile = destination.toFile
+
+ if (destinationFile.getParentFile.exists() || destinationFile.getParentFile.mkdirs()) {
+ if (localSource.renameTo(destinationFile)) ()
+ else {
+ throw new Exception(
+ s"Failed to move file from `${localSource.getCanonicalPath}` to `${destinationFile.getCanonicalPath}`")
+ }
+ } else {
+ throw new Exception(s"Failed to create parent directories for file `${destinationFile.getCanonicalPath}`")
+ }
+ }
+ }
+
+ def download(filePath: Path): OptionT[Future, File] =
+ OptionT.optionT(Future {
+ Option(new File(filePath.toString)).filter(file => file.exists() && file.isFile)
+ })
+
+ def delete(filePath: Path): Future[Unit] = Future {
+ val file = new File(filePath.toString)
+ if (file.delete()) ()
+ else {
+ throw new Exception(s"Failed to delete file $file" + (if (!file.exists()) ", file does not exist." else "."))
+ }
+ }
+
+ def list(path: Path): ListT[Future, FileLink] =
+ ListT.listT(Future {
+ val file = new File(path.toString)
+ if (file.isDirectory) {
+ file.listFiles().toList.filter(_.isFile).map { file =>
+ FileLink(Name[File](file.getName),
+ Paths.get(file.getPath),
+ Revision[File](file.hashCode.toString),
+ Time(file.lastModified()),
+ file.length())
+ }
+ } else List.empty[FileLink]
+ })
+}
diff --git a/src/main/scala/xyz/driver/core/file/GcsStorage.scala b/src/main/scala/xyz/driver/core/file/GcsStorage.scala
new file mode 100644
index 0000000..6c2746e
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/file/GcsStorage.scala
@@ -0,0 +1,82 @@
+package xyz.driver.core.file
+
+import java.io.{BufferedOutputStream, File, FileInputStream, FileOutputStream}
+import java.net.URL
+import java.nio.file.{Path, Paths}
+import java.util.concurrent.TimeUnit
+
+import com.google.cloud.storage.Storage.BlobListOption
+import com.google.cloud.storage._
+import xyz.driver.core.time.Time
+import xyz.driver.core.{Name, Revision, generators}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration.Duration
+import scala.concurrent.{ExecutionContext, Future}
+import scalaz.{ListT, OptionT}
+
+class GcsStorage(storageClient: Storage, bucketName: Name[Bucket], executionContext: ExecutionContext)
+ extends SignedFileStorage {
+ implicit private val execution: ExecutionContext = executionContext
+
+ override def upload(localSource: File, destination: Path): Future[Unit] = Future {
+ checkSafeFileName(destination) {
+ val blobId = BlobId.of(bucketName.value, destination.toString)
+ def acl = Bucket.BlobWriteOption.predefinedAcl(Storage.PredefinedAcl.PUBLIC_READ)
+
+ storageClient.get(bucketName.value).create(blobId.getName, new FileInputStream(localSource), acl)
+ }
+ }
+
+ override def download(filePath: Path): OptionT[Future, File] = {
+ OptionT.optionT(Future {
+ Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map {
+ blob =>
+ val tempDir = System.getProperty("java.io.tmpdir")
+ val randomFolderName = generators.nextUuid().toString
+ val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString)
+
+ if (!tempDestinationFile.getParentFile.mkdirs()) {
+ throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`")
+ } else {
+ val target = new BufferedOutputStream(new FileOutputStream(tempDestinationFile))
+ try target.write(blob.getContent())
+ finally target.close()
+ tempDestinationFile
+ }
+ }
+ })
+ }
+
+ override def delete(filePath: Path): Future[Unit] = Future {
+ storageClient.delete(BlobId.of(bucketName.value, filePath.toString))
+ }
+
+ override def list(path: Path): ListT[Future, FileLink] =
+ ListT.listT(Future {
+ val page = storageClient.list(
+ bucketName.value,
+ BlobListOption.currentDirectory(),
+ BlobListOption.prefix(path.toString)
+ )
+
+ page.iterateAll().asScala.map(blobToFileLink(path, _)).toList
+ })
+
+ protected def blobToFileLink(path: Path, blob: Blob): FileLink = {
+ FileLink(
+ Name(blob.getName),
+ Paths.get(path.toString, blob.getName),
+ Revision(blob.getGeneration.toString),
+ Time(blob.getUpdateTime),
+ blob.getSize
+ )
+ }
+
+ override def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL] =
+ OptionT.optionT(Future {
+ Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map { blob =>
+ blob.signUrl(duration.toSeconds, TimeUnit.SECONDS)
+ }
+ })
+}
diff --git a/src/main/scala/xyz/driver/core/file/S3Storage.scala b/src/main/scala/xyz/driver/core/file/S3Storage.scala
new file mode 100644
index 0000000..933b01a
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/file/S3Storage.scala
@@ -0,0 +1,66 @@
+package xyz.driver.core.file
+
+import java.io.File
+import java.nio.file.{Path, Paths}
+import java.util.UUID.randomUUID
+
+import com.amazonaws.services.s3.AmazonS3
+import com.amazonaws.services.s3.model.{Bucket, GetObjectRequest, ListObjectsV2Request}
+import xyz.driver.core.{Name, Revision}
+import xyz.driver.core.time.Time
+
+import scala.concurrent.{ExecutionContext, Future}
+import scalaz.{ListT, OptionT}
+
+class S3Storage(s3: AmazonS3, bucket: Name[Bucket], executionContext: ExecutionContext) extends FileStorage {
+ implicit private val execution = executionContext
+
+ def upload(localSource: File, destination: Path): Future[Unit] = Future {
+ checkSafeFileName(destination) {
+ val _ = s3.putObject(bucket.value, destination.toString, localSource).getETag
+ }
+ }
+
+ def download(filePath: Path): OptionT[Future, File] =
+ OptionT.optionT(Future {
+ val tempDir = System.getProperty("java.io.tmpdir")
+ val randomFolderName = randomUUID().toString
+ val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString)
+
+ if (!tempDestinationFile.getParentFile.mkdirs()) {
+ throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`")
+ } else {
+ Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString), tempDestinationFile)).map { _ =>
+ tempDestinationFile
+ }
+ }
+ })
+
+ def delete(filePath: Path): Future[Unit] = Future {
+ s3.deleteObject(bucket.value, filePath.toString)
+ }
+
+ def list(path: Path): ListT[Future, FileLink] =
+ ListT.listT(Future {
+ import scala.collection.JavaConverters._
+ val req = new ListObjectsV2Request().withBucketName(bucket.value).withPrefix(path.toString).withMaxKeys(2)
+
+ def isInSubFolder(path: Path)(fileLink: FileLink) =
+ fileLink.location.toString.replace(path.toString + "/", "").contains("/")
+
+ Iterator.continually(s3.listObjectsV2(req)).takeWhile { result =>
+ req.setContinuationToken(result.getNextContinuationToken)
+ result.isTruncated
+ } flatMap { result =>
+ result.getObjectSummaries.asScala.toList.map { summary =>
+ FileLink(
+ Name[File](summary.getKey),
+ Paths.get(path.toString + "/" + summary.getKey),
+ Revision[File](summary.getETag),
+ Time(summary.getLastModified.getTime),
+ summary.getSize
+ )
+ } filterNot isInSubFolder(path)
+ } toList
+ })
+}
diff --git a/src/main/scala/xyz/driver/core/file/package.scala b/src/main/scala/xyz/driver/core/file/package.scala
new file mode 100644
index 0000000..9000894
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/file/package.scala
@@ -0,0 +1,60 @@
+package xyz.driver.core
+
+import java.io.File
+import java.nio.file.Path
+
+import xyz.driver.core.time.Time
+
+import scala.concurrent.Future
+import scalaz.{ListT, OptionT}
+
+package file {
+
+ import java.net.URL
+
+ import scala.concurrent.duration.Duration
+
+ final case class FileLink(
+ name: Name[File],
+ location: Path,
+ revision: Revision[File],
+ lastModificationDate: Time,
+ fileSize: Long
+ )
+
+ trait FileService {
+
+ def getFileLink(id: Name[File]): FileLink
+
+ def getFile(fileLink: FileLink): File
+ }
+
+ trait FileStorage {
+
+ def upload(localSource: File, destination: Path): Future[Unit]
+
+ def download(filePath: Path): OptionT[Future, File]
+
+ def delete(filePath: Path): Future[Unit]
+
+ def list(path: Path): ListT[Future, FileLink]
+
+ /** List of characters to avoid in S3 (I would say file names in general)
+ *
+ * @see http://stackoverflow.com/questions/7116450/what-are-valid-s3-key-names-that-can-be-accessed-via-the-s3-rest-api
+ */
+ private val illegalChars = "\\^`><{}][#%~|&@:,$=+?; "
+
+ protected def checkSafeFileName[T](filePath: Path)(f: => T): T = {
+ filePath.toString.find(c => illegalChars.contains(c)) match {
+ case Some(illegalCharacter) =>
+ throw new IllegalArgumentException(s"File name cannot contain character `$illegalCharacter`")
+ case None => f
+ }
+ }
+ }
+
+ trait SignedFileStorage extends FileStorage {
+ def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL]
+ }
+}
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala
index f1eab45..1db9d09 100644
--- a/src/main/scala/xyz/driver/core/rest.scala
+++ b/src/main/scala/xyz/driver/core/rest.scala
@@ -1,21 +1,24 @@
package xyz.driver.core
+import java.nio.file.{Files, Path}
+import java.security.spec.X509EncodedKeySpec
+import java.security.{KeyFactory, PublicKey}
+
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader}
import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected
-import akka.http.scaladsl.server.Directive0
-import com.typesafe.scalalogging.Logger
-import akka.http.scaladsl.unmarshalling.Unmarshal
-import akka.http.scaladsl.unmarshalling.Unmarshaller
+import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import com.github.swagger.akka.model._
import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService}
import com.typesafe.config.Config
+import com.typesafe.scalalogging.Logger
import io.swagger.models.Scheme
+import pdi.jwt.{Jwt, JwtAlgorithm}
import xyz.driver.core.auth._
import xyz.driver.core.time.provider.TimeProvider
@@ -33,7 +36,7 @@ package rest {
def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
def extractServiceContext(request: HttpRequest): ServiceRequestContext =
- ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+ new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
def extractTrackingId(request: HttpRequest): String = {
request.headers
@@ -43,7 +46,8 @@ package rest {
def extractContextHeaders(request: HttpRequest): Map[String, String] = {
request.headers.filter { h =>
- h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
+ h.name === ContextHeaders.PermissionsTokenHeader
} map { header =>
if (header.name === ContextHeaders.AuthenticationTokenHeader) {
header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
@@ -91,14 +95,59 @@ package rest {
}
}
- final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString,
- contextHeaders: Map[String, String] = Map.empty[String, String]) {
-
+ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString,
+ val contextHeaders: Map[String, String] = Map.empty[String, String]) {
def authToken: Option[AuthToken] =
contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply)
+ def permissionsToken: Option[PermissionsToken] =
+ contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply)
+
def withAuthToken(authToken: AuthToken): ServiceRequestContext =
- copy(contextHeaders = contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value))
+ new ServiceRequestContext(
+ trackingId,
+ contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value)
+ )
+
+ def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] =
+ new AuthorizedServiceRequestContext(
+ trackingId,
+ contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value),
+ user
+ )
+
+ override def hashCode(): Int =
+ Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode())
+
+ override def equals(obj: Any): Boolean = obj match {
+ case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders
+ case _ => false
+ }
+
+ override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)"
+ }
+
+ class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString,
+ override val contextHeaders: Map[String, String] =
+ Map.empty[String, String],
+ val authenticatedUser: U)
+ extends ServiceRequestContext {
+
+ def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] =
+ new AuthorizedServiceRequestContext[U](
+ trackingId,
+ contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value),
+ authenticatedUser)
+
+ override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode()
+
+ override def equals(obj: Any): Boolean = obj match {
+ case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser
+ case _ => false
+ }
+
+ override def toString: String =
+ s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)"
}
object ContextHeaders {
@@ -115,18 +164,78 @@ package rest {
val SetPermissionsTokenHeader = "set-permissions"
}
- trait Authorization {
- def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean]
+ final case class AuthorizationResult(authorized: Boolean, token: Option[PermissionsToken])
+ object AuthorizationResult {
+ val unauthorized: AuthorizationResult = AuthorizationResult(authorized = false, None)
+ }
+
+ trait Authorization[U <: User] {
+ def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult]
+ }
+
+ class AlwaysAllowAuthorization[U <: User](implicit execution: ExecutionContext) extends Authorization[U] {
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] =
+ Future.successful(AuthorizationResult(authorized = true, ctx.permissionsToken))
}
- class AlwaysAllowAuthorization extends Authorization {
- override def userHasPermission(user: User, permission: Permission)(
- implicit ctx: ServiceRequestContext): Future[Boolean] = {
- Future.successful(true)
+ class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] {
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = {
+ import spray.json._
+
+ def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] =
+ tokenObject.fields.get("permissions").collect {
+ case JsObject(fields) =>
+ fields.collect {
+ case (key, JsBoolean(value)) => key -> value
+ }
+ }
+
+ val result = for {
+ token <- ctx.permissionsToken
+ jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption
+ jwtJson = jwt.parseJson.asJsObject
+
+ // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None
+ _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(())
+ _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(())
+
+ permissionsMap <- extractPermissionsFromTokenJSON(jwtJson)
+
+ authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(true))
+ } yield AuthorizationResult(authorized, Some(token))
+
+ Future.successful(result.getOrElse(AuthorizationResult.unauthorized))
+ }
+ }
+
+ object CachedTokenAuthorization {
+ def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = {
+ lazy val publicKey: PublicKey = {
+ val publicKeyBytes = Files.readAllBytes(publicKeyFile)
+ val spec = new X509EncodedKeySpec(publicKeyBytes)
+ KeyFactory.getInstance("RSA").generatePublic(spec)
+ }
+ new CachedTokenAuthorization[U](publicKey, issuer)
+ }
+ }
+
+ class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext)
+ extends Authorization[U] {
+
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = {
+ authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) {
+ (authResult, authorization) =>
+ if (authResult.authorized) Future.successful(authResult)
+ else authorization.userHasPermissions(user, permissions)
+ }
}
}
- abstract class AuthProvider[U <: User](val authorization: Authorization, log: Logger)(
+ abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)(
implicit execution: ExecutionContext) {
import akka.http.scaladsl.server._
@@ -142,43 +251,30 @@ package rest {
def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U]
/**
- * Specific implementation can verify session expiration and single sign out
- * to verify if session is still valid
- */
- def isSessionValid(user: U)(implicit ctx: ServiceRequestContext): Future[Boolean]
-
- /**
* Verifies if request is authenticated and authorized to have `permissions`
*/
- def authorize(permissions: Permission*): Directive1[U] = {
+ def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = {
serviceContext flatMap { ctx =>
- onComplete(authenticatedUser(ctx).run flatMap { userOption =>
- userOption.traverseM[Future, (U, Boolean)] { user =>
- isSessionValid(user)(ctx).flatMap { sessionValid =>
- if (sessionValid) {
- permissions.toList
- .traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx))
- .map(results => Option(user -> results.forall(identity)))
- } else {
- Future.successful(Option.empty[(U, Boolean)])
- }
- }
- }
- }).flatMap {
- case Success(Some((user, authorizationResult))) =>
- if (authorizationResult) provide(user)
- else {
- val challenge =
- HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}")
- log.warn(s"User $user does not have the required permissions: ${permissions.mkString(", ")}")
- reject(AuthenticationFailedRejection(CredentialsRejected, challenge))
- }
-
+ onComplete {
+ (for {
+ authToken <- OptionT.optionT(Future.successful(ctx.authToken))
+ user <- authenticatedUser(ctx)
+ authCtx = ctx.withAuthenticatedUser(authToken, user)
+ authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT
+ cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken)
+ } yield (cachedPermissionsAuthCtx, authorizationResult.authorized)).run
+ } flatMap {
+ case Success(Some((authCtx, true))) => provide(authCtx)
+ case Success(Some((authCtx, false))) =>
+ val challenge =
+ HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}")
+ log.warn(
+ s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}")
+ reject(AuthenticationFailedRejection(CredentialsRejected, challenge))
case Success(None) =>
log.warn(
s"Wasn't able to find authenticated user for the token provided to verify ${permissions.mkString(", ")}")
reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided"))
-
case Failure(t) =>
log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t)
reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t)))
@@ -193,7 +289,6 @@ package rest {
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import spray.json._
- import DefaultJsonProtocol._
protected implicit val exec: ExecutionContext
protected implicit val materializer: ActorMaterializer
@@ -217,10 +312,7 @@ package rest {
protected def jsonEntity(json: JsValue): RequestEntity =
HttpEntity(ContentTypes.`application/json`, json.compactPrint)
- protected def get(baseUri: Uri, path: String) =
- HttpRequest(HttpMethods.GET, endpointUri(baseUri, path))
-
- protected def get(baseUri: Uri, path: String, query: Map[String, String]) =
+ protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) =
HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query))
protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) =
@@ -235,8 +327,8 @@ package rest {
protected def endpointUri(baseUri: Uri, path: String) =
baseUri.withPath(Uri.Path(path))
- protected def endpointUri(baseUri: Uri, path: String, query: Map[String, String]) =
- baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query))
+ protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]) =
+ baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*))
}
trait ServiceTransport {