From 104480d365c3958d7f676b087ff03ed0912410d7 Mon Sep 17 00:00:00 2001 From: mprahl Date: Wed, 8 Nov 2017 21:37:25 -0500 Subject: [PATCH] Set Access-Control-Allow-Origin to * on GET API routes --- module_build_service/utils.py | 27 ++++++++++++++++++++++++++- module_build_service/views.py | 5 ++++- tests/test_views/test_views.py | 4 ++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/module_build_service/utils.py b/module_build_service/utils.py index edc3b7c7..8447dc41 100644 --- a/module_build_service/utils.py +++ b/module_build_service/utils.py @@ -32,11 +32,12 @@ import os import kobo.rpmlib import inspect import hashlib +from functools import wraps import modulemd import yaml -from flask import request, url_for +from flask import request, url_for, Response from datetime import datetime from module_build_service import log, models @@ -1426,3 +1427,27 @@ def create_dogpile_key_generator_func(skip_first_n_args=0): return generate_key return key_generator + + +def cors_header(allow='*'): + """ + A decorator that sets the Access-Control-Allow-Origin header to the desired value on a Flask + route + :param allow: a string of the domain to allow. This defaults to '*'. + """ + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + rv = func(*args, **kwargs) + if rv: + # If a tuple was provided, then the Flask Response should be the first object + if isinstance(rv, tuple): + response = rv[0] + else: + response = rv + # Make sure we are dealing with a Flask Response object + if isinstance(response, Response): + response.headers.add('Access-Control-Allow-Origin', allow) + return rv + return wrapper + return decorator diff --git a/module_build_service/views.py b/module_build_service/views.py index 01dae1d8..3c0f2fb2 100644 --- a/module_build_service/views.py +++ b/module_build_service/views.py @@ -35,7 +35,7 @@ from module_build_service import app, conf, log, models, db, version from module_build_service.utils import ( pagination_metadata, filter_module_builds, filter_component_builds, submit_module_build_from_scm, submit_module_build_from_yaml, - get_scm_url_re) + get_scm_url_re, cors_header) from module_build_service.errors import ( ValidationError, Forbidden, NotFound, ProgrammingError) @@ -90,6 +90,7 @@ api_v1 = { class AbstractQueryableBuildAPI(MethodView): """ An abstract class, housing some common functionality. """ + @cors_header() def get(self, id): verbose_flag = request.args.get('verbose', 'false').lower() @@ -199,6 +200,7 @@ class ModuleBuildAPI(AbstractQueryableBuildAPI): class AboutAPI(MethodView): + @cors_header() def get(self): json = {'version': version} config_items = ['auth_method'] @@ -213,6 +215,7 @@ class AboutAPI(MethodView): class RebuildStrategies(MethodView): + @cors_header() def get(self): items = [] # Sort the items list by name diff --git a/tests/test_views/test_views.py b/tests/test_views/test_views.py index 04b4b431..62808731 100644 --- a/tests/test_views/test_views.py +++ b/tests/test_views/test_views.py @@ -991,3 +991,7 @@ class TestViews(unittest.TestCase): ] } self.assertEquals(data, expected) + + def test_cors_header_decorator(self): + rv = self.client.get('/module-build-service/1/module-builds/') + self.assertEquals(rv.headers['Access-Control-Allow-Origin'], '*')