summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsrc/wsgi.py60
1 files changed, 42 insertions, 18 deletions
diff --git a/src/wsgi.py b/src/wsgi.py
index b690bc0..5f56fb9 100755
--- a/src/wsgi.py
+++ b/src/wsgi.py
@@ -1,3 +1,5 @@
+import os
+import sys
import uvicorn
from fastapi import FastAPI, Depends, Request
@@ -12,19 +14,29 @@ from db import DictDB
app = FastAPI()
db = DictDB()
-public_key = """-----BEGIN PUBLIC KEY-----
-MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPW8bkkVIq4BX8eWwlUOUYbJhiGDv
-K/6xY5T0BsvV6pbMoIUfgeThVOq5I3CmXxLt+qyPska6ol9fTN7woZLsCg==
------END PUBLIC KEY-----"""
+
+def get_pubkey():
+ try:
+ keypath = os.environ['JWT_PUBKEY_PATH']
+
+ with open(keypath, "r") as fd:
+ pubkey = fd.read()
+ except KeyError:
+ print("Could not find environment variable JWT_PUBKEY_PATH")
+ sys.exit(-1)
+ except FileNotFoundError:
+ print(f"Could not find JWT certificate in {keypath}")
+ sys.exit(-1)
+
+ return pubkey
def get_data(key=None, limit=25, skip=0, ip=None,
- port=None, asn=None):
+ port=None, asn=None, domain=None):
selectors = dict()
indexes = CouchIindex().dict()
-
- selectors['domain'] = 'sunet.se'
+ selectors['domain'] = domain
if ip and 'ip' in indexes:
selectors['ip'] = ip
@@ -35,12 +47,12 @@ def get_data(key=None, limit=25, skip=0, ip=None,
data = db.search(**selectors, limit=limit, skip=skip)
- return JSONResponse(content={"status": "success", "data": data})
+ return data
class JWTConfig(BaseModel):
authjwt_algorithm: str = "ES256"
- authjwt_public_key: str = public_key
+ authjwt_public_key: str = get_pubkey()
@AuthJWT.load_config
@@ -67,7 +79,21 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
Authorize.jwt_required()
- return get_data(key, limit, skip, ip, port, asn)
+ data = []
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if 'domains' not in raw_jwt:
+ return JSONResponse(content={"status": "error",
+ "message": "Could not find domains" +
+ "claim in JWT token"},
+ status_code=400)
+ else:
+ domains = raw_jwt['domains']
+
+ for domain in domains:
+ data.extend(get_data(key, limit, skip, ip, port, asn, domain))
+
+ return JSONResponse(content={"statuc": "success", "docs": data})
@app.get('/sc/v0/get/{key}')
@@ -75,19 +101,17 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
- return get_data(key)
+ data = get_data(key)
+
+ return JSONResponse(content={"statuc": "success", "docs": data})
@app.post('/sc/v0/add')
async def add(data: Request, Authorize: AuthJWT = Depends()):
- Authorize.jwt_required()
-
- orgs = ['sunet.se']
-
- if not orgs:
- return JSONResponse(content={"status": "error", "message":
- "Could not find an organization"}, status_code=400)
+ # Maybe we should protect this enpoint too and let the scanner use
+ # a JWT token as well.
+ # Authorize.jwt_required()
json_data = await data.json()