qubed/stac_server/main.py
2024-11-21 14:09:45 +00:00

253 lines
8.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import yaml
from pathlib import Path
import os
from datetime import datetime
from collections import defaultdict
from typing import Any, Dict
import yaml
import os
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse, FileResponse
from fastapi.templating import Jinja2Templates
from TreeTraverser.fdb_schema import FDBSchemaFile
from TreeTraverser.CompressedTree import CompressedTree
import redis
# The ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Cross-origin access is configured fully open: any origin, any HTTP method and
# any request header are accepted, and credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for deployments beyond development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve the site icon.

    The route is registered with ``include_in_schema=False``. The file is
    looked up under the relative name ``favicon.ico`` (presumably resolved
    against the server's working directory — verify for your deployment).
    """
    icon_response = FileResponse("favicon.ico")
    return icon_response
# --- Start-up: everything below runs once, at import time --------------------
# The catalogue is read from Redis under the key 'compressed_catalog'. The Redis
# host is taken from the REDIS_HOST environment variable (default "localhost");
# port 6379 and database 0 are fixed.
print("Getting cache from redis")
r = redis.Redis(host=os.environ.get("REDIS_HOST", "localhost"), port=6379, db=0)
json_data = r.get('compressed_catalog')
# A missing (or empty) entry aborts the import: the handlers below cannot work
# without a catalogue.
if not json_data:
    raise ValueError("No compressed catalog found in Redis")
else:
    print("Found compressed catalog in Redis")
print("Loading tree to json")
compressed_tree_json = json.loads(json_data)
c_tree = CompressedTree.from_json(compressed_tree_json)
print("Partialy decompressing tree, shoud be able to skip this step in future.")
# `tree` is the structure the /match and /paths handlers match requests against.
# Those handlers treat it as nested dicts keyed by "key=value1,value2" strings.
tree = c_tree.reconstruct_compressed_ecmwf_style()
print("Ready to serve requests!")
# Configuration files are looked up under CONFIG_DIR (default: the current
# directory).
base = os.environ.get("CONFIG_DIR", ".")
config = {
    "fdb_schema": f"{base}/schema",
    "mars_language": f"{base}/language.yaml",
}
# Only the "_field" section of the MARS language file is kept; the /stac handler
# reads per-key "type", "description" and "values" entries from it.
with open(config["mars_language"], "r") as f:
    mars_language = yaml.safe_load(f)["_field"]
###### Load FDB Schema
# NOTE(review): `schema` is not referenced by any handler visible in this file
# (only by commented-out code) — confirm whether it is still needed.
schema = FDBSchemaFile(config["fdb_schema"])
def request_to_dict(request: Request) -> Dict[str, Any]:
    """Turn the query parameters of ``request`` into a plain dictionary.

    A value that contains a comma is split on commas into a list of strings;
    every other value is kept as the original string.
    """
    params = dict(request.query_params)
    return {
        name: (raw.split(",") if "," in raw else raw)
        for name, raw in params.items()
    }
def match_against_cache(request, tree):
    """Recursively select the branches of ``tree`` that agree with ``request``.

    ``tree`` is a nested dict whose keys look like ``"key=v1,v2"``. ``request``
    maps key names to either a single string or a list of strings.

    * An empty tree yields the sentinel ``{"_END_": {}}``.
    * A node whose key is present in the request is kept, narrowed to the
      requested values it actually offers (in request order), and its subtree
      is matched recursively.
    * If no node at this level matched, every node of this level is returned
      with an empty subtree.

    Raises:
        ValueError: if a node key does not contain exactly one ``=``.
    """
    if not tree:
        return {"_END_": {}}

    selected = {}
    for node, children in tree.items():
        parts = node.split("=")
        if len(parts) != 2:
            raise ValueError(f"Key {node} is not in the correct format")
        name, joined_values = parts
        available = set(joined_values.split(","))

        # Nodes for keys the request does not mention never match.
        if name not in request:
            continue

        wanted = request[name]
        if isinstance(wanted, list):
            kept = ",".join(item for item in wanted if item in available)
            if kept:
                selected[f"{name}={kept}"] = match_against_cache(request, children)
        elif wanted in available:
            selected[f"{name}={wanted}"] = match_against_cache(request, children)

    if selected:
        return selected
    # Nothing matched here: expose the whole level as the set of next choices.
    return {node: {} for node in tree.keys()}
def max_tree_depth(tree):
    """Return the number of levels along the deepest branch of ``tree``.

    An empty (falsy) tree has depth 0; otherwise the depth is one more than
    the depth of the deepest child.
    """
    if tree:
        return max(map(max_tree_depth, tree.values())) + 1
    return 0
def prune_short_branches(tree, depth = None):
    """Keep only the branches of ``tree`` that are as deep as its deepest one.

    ``depth`` is the depth expected of ``tree``; when omitted it is taken to
    be the tree's own maximum depth. A child is kept only if its maximum depth
    is exactly ``depth - 1``, and is then pruned recursively on that basis.
    """
    if depth is None:
        depth = max_tree_depth(tree)

    child_depth = depth - 1
    pruned = {}
    for name, subtree in tree.items():
        # Children that bottom out early are dropped.
        if max_tree_depth(subtree) != child_depth:
            continue
        pruned[name] = prune_short_branches(subtree, child_depth)
    return pruned
def get_paths_to_leaves(tree):
    """Yield, for each leaf of ``tree``, the list of keys from the root to it.

    A leaf is any entry whose subtree is empty (falsy). Paths are produced in
    the tree's iteration order, depth first.
    """
    for name, children in tree.items():
        if children:
            for tail in get_paths_to_leaves(children):
                yield [name, *tail]
        else:
            yield [name]
def get_leaves(tree):
    """Yield the key of every leaf of ``tree``, depth first.

    A leaf is any entry whose subtree is empty (falsy).
    """
    for name, children in tree.items():
        if children:
            yield from get_leaves(children)
        else:
            yield name
@app.get("/match")
async def get_match(request: Request):
    """Return the part of the catalogue tree that matches the query string.

    The query parameters are converted to a request dictionary, matched
    against the module-level ``tree``, and the result is pruned before being
    returned.
    """
    query = request_to_dict(request)
    matched = match_against_cache(query, tree)
    # Only branches as deep as the deepest match are kept. That way, once a
    # branch has not been chosen at some point, the UI stops asking for a
    # value on that branch.
    return prune_short_branches(matched)
@app.get("/paths")
async def api_paths(request: Request):
    """List the keys that can be chosen next for the query, with their values.

    The query is matched against the module-level ``tree`` and pruned exactly
    as in ``/match``. Every root-to-leaf path of the result is then grouped by
    the key of its leaf.

    Returns:
        A list with one dict per leaf key:

        * ``"key"``: the key name,
        * ``"values"``: all values seen for that key, sorted in descending
          order,
        * ``"paths"``: the distinct ``"key=value"`` prefixes (as tuples)
          leading to that key, sorted so the response is identical on every
          run.

        Paths ending in the ``"_END_"`` sentinel are skipped, so a fully
        specified request produces an empty list.
    """
    request_dict = request_to_dict(request)
    match_tree = prune_short_branches(match_against_cache(request_dict, tree))

    # Deduplicate leaves based on their key.
    by_key = defaultdict(lambda: {"paths": set(), "values": set()})
    for path in get_paths_to_leaves(match_tree):
        *prefix, leaf = path
        if leaf == "_END_":
            continue
        key, values = leaf.split("=")
        by_key[key]["values"].update(values.split(","))
        by_key[key]["paths"].add(tuple(prefix))

    return [
        {
            # Set iteration order depends on string hashing, which varies
            # between processes; sort so consumers that pick the first path
            # always get the same one.
            "paths": sorted(entry["paths"]),
            "key": key,
            "values": sorted(entry["values"], reverse=True),
        }
        for key, entry in by_key.items()
    ]
@app.get("/stac")
async def get_STAC(request: Request):
    """Describe the next selectable keys for the query as a STAC collection.

    The candidate keys come from ``api_paths``; each becomes one child link
    carrying an href template and a ``generalized_datacube:dimension``
    description built from the MARS language definition. The ``"debug"``
    section echoes the parsed request, descriptions of the values already
    chosen, and the raw ``api_paths`` result.

    Returns:
        A dict shaped like a STAC Collection with ``"links"`` and ``"debug"``
        entries.
    """
    request_dict = request_to_dict(request)
    paths = await api_paths(request)

    def make_link(key_name, paths, values):
        """Take a MARS Key and information about which paths matched up to this point and use it to make a STAC Link"""
        # Only the first matching path is used to build the link target.
        path = paths[0]
        href_template = f"/stac?{'&'.join(path)}{'&' if path else ''}{key_name}={{}}"
        # Optionality is not derived from the data; every key is currently
        # reported as non-optional.
        optional = [False]
        key_info = mars_language.get(key_name, {})
        values_from_mars_language = key_info.get("values", [])

        # Default to None so the link can still be built when the language
        # file lists plain values instead of [alias..., description] lists.
        # Without this default the name was unbound on that path.
        value_descriptions = None
        if all(isinstance(v, list) for v in values_from_mars_language):
            # Every alias (all entries but the last) maps to the description
            # held in the last entry.
            value_descriptions_dict = {
                alias: entry[-1]
                for entry in values_from_mars_language
                if len(entry) > 1
                for alias in entry[:-1]
            }
            value_descriptions = [value_descriptions_dict.get(v, "") for v in values]
            if not any(value_descriptions):
                value_descriptions = None

        return {
            "title": key_name,
            "generalized_datacube:href_template": href_template,
            "rel": "child",
            "type": "application/json",
            "generalized_datacube:dimension": {
                "type": key_info.get("type", ""),
                "description": key_info.get("description", ""),
                "values": values,
                "value_descriptions": value_descriptions,
                "optional": any(optional),
                "multiple": True,
                "paths": paths,
            },
        }

    def describe_values(key, values):
        """Map each requested value of ``key`` to its description, where one is known."""
        # A single requested value arrives as a plain string; wrap it so the
        # membership test is an exact match rather than a substring test.
        requested = set(values) if isinstance(values, list) else {values}
        return {
            entry[0]: entry[-1]
            for entry in mars_language.get(key, {}).get("values", [])
            # Only [name, ..., description] lists carry a description.
            if isinstance(entry, list) and len(entry) > 1 and entry[0] in requested
        }

    descriptions = {
        key: {
            "key": key,
            "values": values,
            "description": mars_language.get(key, {}).get("description", ""),
            "value_descriptions": describe_values(key, values),
        }
        for key, values in request_dict.items()
    }

    # Format the response as a STAC collection
    stac_collection = {
        "type": "Collection",
        "stac_version": "1.0.0",
        "id": "partial-matches",
        "description": "STAC collection representing potential children of this request",
        "links": [
            make_link(p["key"], p["paths"], p["values"])
            for p in paths
        ],
        "debug": {
            "request": request_dict,
            "descriptions": descriptions,
            "paths": paths,
        },
    }
    return stac_collection