diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/xyz/driver/core/app.scala | 11 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/auth.scala | 1 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/file.scala | 155 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/file/FileSystemStorage.scala | 57 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/file/GcsStorage.scala | 82 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/file/S3Storage.scala | 66 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/file/package.scala | 60 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/rest.scala | 200 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/AuthTest.scala | 69 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/DateTest.scala | 2 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/FileTest.scala | 71 |
11 files changed, 544 insertions, 230 deletions
diff --git a/src/main/scala/xyz/driver/core/app.scala b/src/main/scala/xyz/driver/core/app.scala index f7731e3..e35c300 100644 --- a/src/main/scala/xyz/driver/core/app.scala +++ b/src/main/scala/xyz/driver/core/app.scala @@ -26,6 +26,8 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} import scala.compat.Platform.ConcurrentModificationException import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scalaz.Scalaz.stringInstance +import scalaz.syntax.equal._ object app { @@ -59,6 +61,9 @@ object app { } } + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = + request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + protected def bindHttp(modules: Seq[Module]): Unit = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, serviceTypes, config) @@ -74,6 +79,12 @@ object app { val trackingId = rest.extractTrackingId(ctx.request) MDC.put("trackingId", trackingId) MDC.put("origin", origin) + MDC.put("xForwardedFor", + extractHeader(ctx.request)("x-forwarded-for") + .orElse(extractHeader(ctx.request)("x_forwarded_for")) + .getOrElse("unknown")) + MDC.put("remoteAddress", extractHeader(ctx.request)("remote-address").getOrElse("unknown")) + MDC.put("userAgent", extractHeader(ctx.request)("user-agent").getOrElse("unknown")) MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown")) def requestLogging: Future[Unit] = Future { diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala index f9a1a57..5dea2db 100644 --- a/src/main/scala/xyz/driver/core/auth.scala +++ b/src/main/scala/xyz/driver/core/auth.scala @@ -23,6 +23,7 @@ object auth { final case class AuthToken(value: String) final case class RefreshToken(value: String) + final case class PermissionsToken(value: String) final case class PasswordHash(value: String) diff --git a/src/main/scala/xyz/driver/core/file.scala b/src/main/scala/xyz/driver/core/file.scala deleted file mode 100644 index dcc4b87..0000000 --- a/src/main/scala/xyz/driver/core/file.scala +++ /dev/null @@ -1,155 +0,0 @@ -package xyz.driver.core - -import java.io.File -import java.nio.file.{Path, Paths} -import java.util.UUID._ - -import com.amazonaws.services.s3.AmazonS3 -import com.amazonaws.services.s3.model.{Bucket, GetObjectRequest, ListObjectsV2Request} -import xyz.driver.core.time.Time - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.{ListT, OptionT} - -object file { - - final case class FileLink( - name: Name[File], - location: Path, - revision: Revision[File], - lastModificationDate: Time, - fileSize: Long - ) - - trait FileService { - - def getFileLink(id: Name[File]): FileLink - - def getFile(fileLink: FileLink): File - } - - trait FileStorage { - - def upload(localSource: File, destination: Path): Future[Unit] - - def download(filePath: Path): OptionT[Future, File] - - def delete(filePath: Path): Future[Unit] - - def list(path: Path): ListT[Future, FileLink] - - /** List of characters to avoid in S3 (I would say file names in general) - * - * @see http://stackoverflow.com/questions/7116450/what-are-valid-s3-key-names-that-can-be-accessed-via-the-s3-rest-api - */ - private val illegalChars = "\\^`><{}][#%~|&@:,$=+?; " - - protected def checkSafeFileName[T](filePath: Path)(f: => T): T = { - filePath.toString.find(c => illegalChars.contains(c)) match { - case Some(illegalCharacter) => - throw new IllegalArgumentException(s"File name cannot contain character `$illegalCharacter`") - case None => f - } - } - } - - class S3Storage(s3: AmazonS3, bucket: Name[Bucket], executionContext: ExecutionContext) extends FileStorage { - implicit private val execution = executionContext - - def upload(localSource: File, destination: Path): Future[Unit] = Future { - checkSafeFileName(destination) { - val _ = s3.putObject(bucket.value, destination.toString, localSource).getETag - } - } - - def download(filePath: Path): OptionT[Future, File] = - OptionT.optionT(Future { - val tempDir = System.getProperty("java.io.tmpdir") - val randomFolderName = randomUUID().toString - val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString) - - if (!tempDestinationFile.getParentFile.mkdirs()) { - throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`") - } else { - Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString), tempDestinationFile)).map { _ => - tempDestinationFile - } - } - }) - - def delete(filePath: Path): Future[Unit] = Future { - s3.deleteObject(bucket.value, filePath.toString) - } - - def list(path: Path): ListT[Future, FileLink] = - ListT.listT(Future { - import scala.collection.JavaConverters._ - val req = new ListObjectsV2Request().withBucketName(bucket.value).withPrefix(path.toString).withMaxKeys(2) - - def isInSubFolder(path: Path)(fileLink: FileLink) = - fileLink.location.toString.replace(path.toString + "/", "").contains("/") - - Iterator.continually(s3.listObjectsV2(req)).takeWhile { result => - req.setContinuationToken(result.getNextContinuationToken) - result.isTruncated - } flatMap { result => - result.getObjectSummaries.asScala.toList.map { summary => - FileLink( - Name[File](summary.getKey), - Paths.get(path.toString + "/" + summary.getKey), - Revision[File](summary.getETag), - Time(summary.getLastModified.getTime), - summary.getSize - ) - } filterNot isInSubFolder(path) - } toList - }) - } - - class FileSystemStorage(executionContext: ExecutionContext) extends FileStorage { - implicit private val execution = executionContext - - def upload(localSource: File, destination: Path): Future[Unit] = Future { - checkSafeFileName(destination) { - val destinationFile = destination.toFile - - if (destinationFile.getParentFile.exists() || destinationFile.getParentFile.mkdirs()) { - if (localSource.renameTo(destinationFile)) () - else { - throw new Exception( - s"Failed to move file from `${localSource.getCanonicalPath}` to `${destinationFile.getCanonicalPath}`") - } - } else { - throw new Exception(s"Failed to create parent directories for file `${destinationFile.getCanonicalPath}`") - } - } - } - - def download(filePath: Path): OptionT[Future, File] = - OptionT.optionT(Future { - Option(new File(filePath.toString)).filter(file => file.exists() && file.isFile) - }) - - def delete(filePath: Path): Future[Unit] = Future { - val file = new File(filePath.toString) - if (file.delete()) () - else { - throw new Exception(s"Failed to delete file $file" + (if (!file.exists()) ", file does not exist." else ".")) - } - } - - def list(path: Path): ListT[Future, FileLink] = - ListT.listT(Future { - val file = new File(path.toString) - if (file.isDirectory) { - file.listFiles().toList.filter(_.isFile).map { file => - FileLink(Name[File](file.getName), - Paths.get(file.getPath), - Revision[File](file.hashCode.toString), - Time(file.lastModified()), - file.length()) - } - } else List.empty[FileLink] - }) - } -} diff --git a/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala b/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala new file mode 100644 index 0000000..bfe6995 --- /dev/null +++ b/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala @@ -0,0 +1,57 @@ +package xyz.driver.core.file + +import java.io.File +import java.nio.file.{Path, Paths} + +import xyz.driver.core.{Name, Revision} +import xyz.driver.core.time.Time + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +class FileSystemStorage(executionContext: ExecutionContext) extends FileStorage { + implicit private val execution = executionContext + + def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val destinationFile = destination.toFile + + if (destinationFile.getParentFile.exists() || destinationFile.getParentFile.mkdirs()) { + if (localSource.renameTo(destinationFile)) () + else { + throw new Exception( + s"Failed to move file from `${localSource.getCanonicalPath}` to `${destinationFile.getCanonicalPath}`") + } + } else { + throw new Exception(s"Failed to create parent directories for file `${destinationFile.getCanonicalPath}`") + } + } + } + + def download(filePath: Path): OptionT[Future, File] = + OptionT.optionT(Future { + Option(new File(filePath.toString)).filter(file => file.exists() && file.isFile) + }) + + def delete(filePath: Path): Future[Unit] = Future { + val file = new File(filePath.toString) + if (file.delete()) () + else { + throw new Exception(s"Failed to delete file $file" + (if (!file.exists()) ", file does not exist." else ".")) + } + } + + def list(path: Path): ListT[Future, FileLink] = + ListT.listT(Future { + val file = new File(path.toString) + if (file.isDirectory) { + file.listFiles().toList.filter(_.isFile).map { file => + FileLink(Name[File](file.getName), + Paths.get(file.getPath), + Revision[File](file.hashCode.toString), + Time(file.lastModified()), + file.length()) + } + } else List.empty[FileLink] + }) +} diff --git a/src/main/scala/xyz/driver/core/file/GcsStorage.scala b/src/main/scala/xyz/driver/core/file/GcsStorage.scala new file mode 100644 index 0000000..6c2746e --- /dev/null +++ b/src/main/scala/xyz/driver/core/file/GcsStorage.scala @@ -0,0 +1,82 @@ +package xyz.driver.core.file + +import java.io.{BufferedOutputStream, File, FileInputStream, FileOutputStream} +import java.net.URL +import java.nio.file.{Path, Paths} +import java.util.concurrent.TimeUnit + +import com.google.cloud.storage.Storage.BlobListOption +import com.google.cloud.storage._ +import xyz.driver.core.time.Time +import xyz.driver.core.{Name, Revision, generators} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.Duration +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +class GcsStorage(storageClient: Storage, bucketName: Name[Bucket], executionContext: ExecutionContext) + extends SignedFileStorage { + implicit private val execution: ExecutionContext = executionContext + + override def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val blobId = BlobId.of(bucketName.value, destination.toString) + def acl = Bucket.BlobWriteOption.predefinedAcl(Storage.PredefinedAcl.PUBLIC_READ) + + storageClient.get(bucketName.value).create(blobId.getName, new FileInputStream(localSource), acl) + } + } + + override def download(filePath: Path): OptionT[Future, File] = { + OptionT.optionT(Future { + Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map { + blob => + val tempDir = System.getProperty("java.io.tmpdir") + val randomFolderName = generators.nextUuid().toString + val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString) + + if (!tempDestinationFile.getParentFile.mkdirs()) { + throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`") + } else { + val target = new BufferedOutputStream(new FileOutputStream(tempDestinationFile)) + try target.write(blob.getContent()) + finally target.close() + tempDestinationFile + } + } + }) + } + + override def delete(filePath: Path): Future[Unit] = Future { + storageClient.delete(BlobId.of(bucketName.value, filePath.toString)) + } + + override def list(path: Path): ListT[Future, FileLink] = + ListT.listT(Future { + val page = storageClient.list( + bucketName.value, + BlobListOption.currentDirectory(), + BlobListOption.prefix(path.toString) + ) + + page.iterateAll().asScala.map(blobToFileLink(path, _)).toList + }) + + protected def blobToFileLink(path: Path, blob: Blob): FileLink = { + FileLink( + Name(blob.getName), + Paths.get(path.toString, blob.getName), + Revision(blob.getGeneration.toString), + Time(blob.getUpdateTime), + blob.getSize + ) + } + + override def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL] = + OptionT.optionT(Future { + Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map { blob => + blob.signUrl(duration.toSeconds, TimeUnit.SECONDS) + } + }) +} diff --git a/src/main/scala/xyz/driver/core/file/S3Storage.scala b/src/main/scala/xyz/driver/core/file/S3Storage.scala new file mode 100644 index 0000000..933b01a --- /dev/null +++ b/src/main/scala/xyz/driver/core/file/S3Storage.scala @@ -0,0 +1,66 @@ +package xyz.driver.core.file + +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID.randomUUID + +import com.amazonaws.services.s3.AmazonS3 +import com.amazonaws.services.s3.model.{Bucket, GetObjectRequest, ListObjectsV2Request} +import xyz.driver.core.{Name, Revision} +import xyz.driver.core.time.Time + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +class S3Storage(s3: AmazonS3, bucket: Name[Bucket], executionContext: ExecutionContext) extends FileStorage { + implicit private val execution = executionContext + + def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val _ = s3.putObject(bucket.value, destination.toString, localSource).getETag + } + } + + def download(filePath: Path): OptionT[Future, File] = + OptionT.optionT(Future { + val tempDir = System.getProperty("java.io.tmpdir") + val randomFolderName = randomUUID().toString + val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString) + + if (!tempDestinationFile.getParentFile.mkdirs()) { + throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`") + } else { + Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString), tempDestinationFile)).map { _ => + tempDestinationFile + } + } + }) + + def delete(filePath: Path): Future[Unit] = Future { + s3.deleteObject(bucket.value, filePath.toString) + } + + def list(path: Path): ListT[Future, FileLink] = + ListT.listT(Future { + import scala.collection.JavaConverters._ + val req = new ListObjectsV2Request().withBucketName(bucket.value).withPrefix(path.toString).withMaxKeys(2) + + def isInSubFolder(path: Path)(fileLink: FileLink) = + fileLink.location.toString.replace(path.toString + "/", "").contains("/") + + Iterator.continually(s3.listObjectsV2(req)).takeWhile { result => + req.setContinuationToken(result.getNextContinuationToken) + result.isTruncated + } flatMap { result => + result.getObjectSummaries.asScala.toList.map { summary => + FileLink( + Name[File](summary.getKey), + Paths.get(path.toString + "/" + summary.getKey), + Revision[File](summary.getETag), + Time(summary.getLastModified.getTime), + summary.getSize + ) + } filterNot isInSubFolder(path) + } toList + }) +} diff --git a/src/main/scala/xyz/driver/core/file/package.scala b/src/main/scala/xyz/driver/core/file/package.scala new file mode 100644 index 0000000..9000894 --- /dev/null +++ b/src/main/scala/xyz/driver/core/file/package.scala @@ -0,0 +1,60 @@ +package xyz.driver.core + +import java.io.File +import java.nio.file.Path + +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.{ListT, OptionT} + +package file { + + import java.net.URL + + import scala.concurrent.duration.Duration + + final case class FileLink( + name: Name[File], + location: Path, + revision: Revision[File], + lastModificationDate: Time, + fileSize: Long + ) + + trait FileService { + + def getFileLink(id: Name[File]): FileLink + + def getFile(fileLink: FileLink): File + } + + trait FileStorage { + + def upload(localSource: File, destination: Path): Future[Unit] + + def download(filePath: Path): OptionT[Future, File] + + def delete(filePath: Path): Future[Unit] + + def list(path: Path): ListT[Future, FileLink] + + /** List of characters to avoid in S3 (I would say file names in general) + * + * @see http://stackoverflow.com/questions/7116450/what-are-valid-s3-key-names-that-can-be-accessed-via-the-s3-rest-api + */ + private val illegalChars = "\\^`><{}][#%~|&@:,$=+?; " + + protected def checkSafeFileName[T](filePath: Path)(f: => T): T = { + filePath.toString.find(c => illegalChars.contains(c)) match { + case Some(illegalCharacter) => + throw new IllegalArgumentException(s"File name cannot contain character `$illegalCharacter`") + case None => f + } + } + } + + trait SignedFileStorage extends FileStorage { + def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL] + } +} diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index f1eab45..1db9d09 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -1,21 +1,24 @@ package xyz.driver.core +import java.nio.file.{Files, Path} +import java.security.spec.X509EncodedKeySpec +import java.security.{KeyFactory, PublicKey} + import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected -import akka.http.scaladsl.server.Directive0 -import com.typesafe.scalalogging.Logger -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.http.scaladsl.unmarshalling.Unmarshaller +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} import akka.stream.ActorMaterializer import akka.stream.scaladsl.Flow import akka.util.ByteString import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger import io.swagger.models.Scheme +import pdi.jwt.{Jwt, JwtAlgorithm} import xyz.driver.core.auth._ import xyz.driver.core.time.provider.TimeProvider @@ -33,7 +36,7 @@ package rest { def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) def extractServiceContext(request: HttpRequest): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) def extractTrackingId(request: HttpRequest): String = { request.headers @@ -43,7 +46,8 @@ package rest { def extractContextHeaders(request: HttpRequest): Map[String, String] = { request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader } map { header => if (header.name === ContextHeaders.AuthenticationTokenHeader) { header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim @@ -91,14 +95,59 @@ package rest { } } - final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) { - + class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + def withAuthToken(authToken: AuthToken): ServiceRequestContext = - copy(contextHeaders = contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value)) + new ServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders + case _ => false + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" + } + + class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" } object ContextHeaders { @@ -115,18 +164,78 @@ package rest { val SetPermissionsTokenHeader = "set-permissions" } - trait Authorization { - def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean] + final case class AuthorizationResult(authorized: Boolean, token: Option[PermissionsToken]) + object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = false, None) + } + + trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] + } + + class AlwaysAllowAuthorization[U <: User](implicit execution: ExecutionContext) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = + Future.successful(AuthorizationResult(authorized = true, ctx.permissionsToken)) } - class AlwaysAllowAuthorization extends Authorization { - override def userHasPermission(user: User, permission: Permission)( - implicit ctx: ServiceRequestContext): Future[Boolean] = { - Future.successful(true) + class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(true)) + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } + } + + object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBytes = Files.readAllBytes(publicKeyFile) + val spec = new X509EncodedKeySpec(publicKeyBytes) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } + } + + class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (authResult.authorized) Future.successful(authResult) + else authorization.userHasPermissions(user, permissions) + } } } - abstract class AuthProvider[U <: User](val authorization: Authorization, log: Logger)( + abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( implicit execution: ExecutionContext) { import akka.http.scaladsl.server._ @@ -142,43 +251,30 @@ package rest { def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] /** - * Specific implementation can verify session expiration and single sign out - * to verify if session is still valid - */ - def isSessionValid(user: U)(implicit ctx: ServiceRequestContext): Future[Boolean] - - /** * Verifies if request is authenticated and authorized to have `permissions` */ - def authorize(permissions: Permission*): Directive1[U] = { + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { serviceContext flatMap { ctx => - onComplete(authenticatedUser(ctx).run flatMap { userOption => - userOption.traverseM[Future, (U, Boolean)] { user => - isSessionValid(user)(ctx).flatMap { sessionValid => - if (sessionValid) { - permissions.toList - .traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx)) - .map(results => Option(user -> results.forall(identity))) - } else { - Future.successful(Option.empty[(U, Boolean)]) - } - } - } - }).flatMap { - case Success(Some((user, authorizationResult))) => - if (authorizationResult) provide(user) - else { - val challenge = - HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") - log.warn(s"User $user does not have the required permissions: ${permissions.mkString(", ")}") - reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) - } - + onComplete { + (for { + authToken <- OptionT.optionT(Future.successful(ctx.authToken)) + user <- authenticatedUser(ctx) + authCtx = ctx.withAuthenticatedUser(authToken, user) + authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + } yield (cachedPermissionsAuthCtx, authorizationResult.authorized)).run + } flatMap { + case Success(Some((authCtx, true))) => provide(authCtx) + case Success(Some((authCtx, false))) => + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + log.warn( + s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) case Success(None) => log.warn( s"Wasn't able to find authenticated user for the token provided to verify ${permissions.mkString(", ")}") reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) - case Failure(t) => log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t) reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) @@ -193,7 +289,6 @@ package rest { import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ import spray.json._ - import DefaultJsonProtocol._ protected implicit val exec: ExecutionContext protected implicit val materializer: ActorMaterializer @@ -217,10 +312,7 @@ package rest { protected def jsonEntity(json: JsValue): RequestEntity = HttpEntity(ContentTypes.`application/json`, json.compactPrint) - protected def get(baseUri: Uri, path: String) = - HttpRequest(HttpMethods.GET, endpointUri(baseUri, path)) - - protected def get(baseUri: Uri, path: String, query: Map[String, String]) = + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = @@ -235,8 +327,8 @@ package rest { protected def endpointUri(baseUri: Uri, path: String) = baseUri.withPath(Uri.Path(path)) - protected def endpointUri(baseUri: Uri, path: String, query: Map[String, String]) = - baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query)) + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]) = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) } trait ServiceTransport { diff --git a/src/test/scala/xyz/driver/core/AuthTest.scala b/src/test/scala/xyz/driver/core/AuthTest.scala index ad8cec8..bf776df 100644 --- a/src/test/scala/xyz/driver/core/AuthTest.scala +++ b/src/test/scala/xyz/driver/core/AuthTest.scala @@ -7,34 +7,47 @@ import akka.http.scaladsl.server._ import akka.http.scaladsl.testkit.ScalatestRouteTest import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpec, Matchers} +import pdi.jwt.{Jwt, JwtAlgorithm} import xyz.driver.core.auth._ import xyz.driver.core.logging._ -import xyz.driver.core.rest.{AuthProvider, Authorization, ServiceRequestContext} +import xyz.driver.core.rest._ import scala.concurrent.Future import scalaz.OptionT class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRouteTest { - case object TestRoleAllowedPermission extends Permission - case object TestRoleNotAllowedPermission extends Permission + case object TestRoleAllowedPermission extends Permission + case object TestRoleAllowedByTokenPermission extends Permission + case object TestRoleNotAllowedPermission extends Permission val TestRole = Role(Id("1"), Name("testRole")) - implicit val exec = scala.concurrent.ExecutionContext.global + val (publicKey, privateKey) = { + import java.security.KeyPairGenerator - val authorization: Authorization = new Authorization { - override def userHasPermission(user: User, permission: Permission)( - implicit ctx: ServiceRequestContext): Future[Boolean] = { - Future.successful(permission === TestRoleAllowedPermission) + val keygen = KeyPairGenerator.getInstance("RSA") + keygen.initialize(2048) + + val keyPair = keygen.generateKeyPair() + (keyPair.getPublic, keyPair.getPrivate) + } + + val basicAuthorization: Authorization[User] = new Authorization[User] { + + override def userHasPermissions(user: User, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val authorized = permissions.forall(_ === TestRoleAllowedPermission) + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) } } - val authStatusService = new AuthProvider[User](authorization, NoLogger) { + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) - override def isSessionValid(user: User)(implicit ctx: ServiceRequestContext): Future[Boolean] = - Future.successful(true) + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + val authStatusService = new AuthProvider[User](authorization, NoLogger) { override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = OptionT.optionT[Future] { if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { @@ -47,7 +60,7 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo import authStatusService._ - "'authorize' directive" should "throw error is auth token is not in the request" in { + "'authorize' directive" should "throw error if auth token is not in the request" in { Get("/naive/attempt") ~> authorize(TestRoleAllowedPermission) { user => @@ -59,7 +72,7 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo } } - it should "throw error is authorized user is not having the requested permission" in { + it should "throw error if authorized user does not have the requested permission" in { val referenceAuthToken = AuthToken("I am a test role's token") @@ -85,12 +98,36 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo Get("/valid/attempt/?a=2&b=5").addHeader( RawHeader(AuthProvider.AuthenticationTokenHeader, referenceAuthToken.value) ) ~> - authorize(TestRoleAllowedPermission) { user => - complete("Alright, user \"" + user.id + "\" is authorized") + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authorize permission found in permissions token" in { + import spray.json._ + + val claim = JsObject( + Map( + "iss" -> JsString(tokenIssuer), + "sub" -> JsString("1"), + "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) + )).prettyPrint + val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) + val referenceAuthToken = AuthToken("I am token") + + Get("/alic/attempt/?a=2&b=5") + .addHeader(RawHeader(AuthProvider.AuthenticationTokenHeader, referenceAuthToken.value)) + .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> + authorize(TestRoleAllowedByTokenPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") } ~> check { handled shouldBe true - responseAs[String] shouldBe "Alright, user \"1\" is authorized" + responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" } } } diff --git a/src/test/scala/xyz/driver/core/DateTest.scala b/src/test/scala/xyz/driver/core/DateTest.scala index c1185cd..0cf8a9e 100644 --- a/src/test/scala/xyz/driver/core/DateTest.scala +++ b/src/test/scala/xyz/driver/core/DateTest.scala @@ -38,7 +38,7 @@ class DateTest extends FlatSpec with Matchers with Checkers { case Seq(a, b) => if (a.year == b.year) { if (a.month == b.month) { - a.day < b.day + a.day <= b.day } else { a.month < b.month } diff --git a/src/test/scala/xyz/driver/core/FileTest.scala b/src/test/scala/xyz/driver/core/FileTest.scala index a8379cf..a1b0329 100644 --- a/src/test/scala/xyz/driver/core/FileTest.scala +++ b/src/test/scala/xyz/driver/core/FileTest.scala @@ -1,15 +1,13 @@ package xyz.driver.core -import java.io.File +import java.io.{File, FileInputStream} import java.nio.file.Paths -import com.amazonaws.services.s3.AmazonS3 -import com.amazonaws.services.s3.model._ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpec, Matchers} -import xyz.driver.core.file.{FileSystemStorage, S3Storage} +import xyz.driver.core.file.{FileSystemStorage, GcsStorage, S3Storage} import scala.concurrent.Await import scala.concurrent.duration._ @@ -17,6 +15,8 @@ import scala.concurrent.duration._ class FileTest extends FlatSpec with Matchers with MockitoSugar { "S3 Storage" should "create and download local files and do other operations" in { + import com.amazonaws.services.s3.AmazonS3 + import com.amazonaws.services.s3.model._ import scala.collection.JavaConverters._ val tempDir = System.getProperty("java.io.tmpdir") @@ -116,6 +116,69 @@ class FileTest extends FlatSpec with Matchers with MockitoSugar { filesAfterRemoval shouldBe empty } + "Google Cloud Storage" should "upload and download files" in { + import com.google.cloud.Page + import com.google.cloud.storage.{Blob, Bucket, Storage} + import Bucket.BlobWriteOption + import Storage.BlobListOption + import scala.collection.JavaConverters._ + + val tempDir = System.getProperty("java.io.tmpdir") + val sourceTestFile = generateTestLocalFile(tempDir) + val testFileName = "uploadTestFile" + + val randomFolderName = java.util.UUID.randomUUID().toString + val testDirPath = Paths.get(randomFolderName) + val testFilePath = Paths.get(randomFolderName, testFileName) + + val testBucket = Name[Bucket]("IamBucket") + val gcsMock = mock[Storage] + val pageMock = mock[Page[Blob]] + val bucketMock = mock[Bucket] + val blobMock = mock[Blob] + + when(blobMock.getName).thenReturn(testFileName) + when(blobMock.getGeneration).thenReturn(1000L) + when(blobMock.getUpdateTime).thenReturn(1493422254L) + when(blobMock.getSize).thenReturn(12345L) + when(blobMock.getContent()).thenReturn(Array[Byte](1, 2, 3)) + + val gcsStorage = new GcsStorage(gcsMock, testBucket, scala.concurrent.ExecutionContext.global) + + when(pageMock.iterateAll()).thenReturn( + Iterator[Blob]().asJava, + Iterator[Blob](blobMock).asJava, + Iterator[Blob]().asJava + ) + when( + gcsMock.list(testBucket.value, BlobListOption.currentDirectory(), BlobListOption.prefix(testDirPath.toString))) + .thenReturn(pageMock) + + val filesBefore = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesBefore shouldBe empty + + when(gcsMock.get(testBucket.value)).thenReturn(bucketMock) + when(gcsMock.get(testBucket.value, testFilePath.toString)).thenReturn(blobMock) + when(bucketMock.create(org.mockito.Matchers.eq(testFileName), any[FileInputStream], any[BlobWriteOption])) + .thenReturn(blobMock) + + Await.result(gcsStorage.upload(sourceTestFile, testFilePath), 10 seconds) + + val filesAfterUpload = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesAfterUpload.size should be(1) + + val downloadedFile = Await.result(gcsStorage.download(testFilePath).run, 10 seconds) + downloadedFile shouldBe defined + downloadedFile.foreach { + _.getAbsolutePath should endWith(testFilePath.toString) + } + + Await.result(gcsStorage.delete(testFilePath), 10 seconds) + + val filesAfterRemoval = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesAfterRemoval shouldBe empty + } + private def generateTestLocalFile(path: String): File = { val randomSourceFolderName = java.util.UUID.randomUUID().toString val sourceTestFile = new File(Paths.get(path, randomSourceFolderName, "uploadTestFile").toString) |