Files
osm-search/api/api.py

345 lines
10 KiB
Python

from psycopg2 import sql
import psycopg2
from psycopg2.extras import RealDictCursor
from flask import Flask, request, jsonify, abort, Response
from flask_cors import CORS
import json
from functools import wraps
import os
from loguru import logger
from datetime import datetime
import math
import firebase_admin
from firebase_admin import credentials
from firebase_admin import auth
from firebase_admin import firestore
import time
cred = credentials.Certificate("service_account.json")
firebase_app = firebase_admin.initialize_app(cred)
db = firestore.client()
app = Flask(__name__)
app.config["TEMPLATES_AUTO_RELOAD"] = True
CORS(app)
ALLOWED_COMPARISONS = [
"=",
"!=",
">",
"<",
">=",
"<=",
"starts with",
"ends with",
"contains",
"does not contain",
"is null",
"is not null",
]
ALLOWED_METHODS = ["OR", "AND"]
ALLOWED_CASTS = ["integer", "float", "cast_to_int", "cast_to_float"]
def get_db_connection():
conn = psycopg2.connect(
database=os.environ.get("PG_DB"),
host=os.environ.get("PG_HOST"),
port=os.environ.get("PG_PORT"),
user=os.environ.get("PG_USER"),
password=os.environ.get("PG_PASSWORD"),
)
return conn
def query_with_timing(query, conn=None):
if conn is None:
conn = get_db_connection()
cur = conn.cursor(cursor_factory=RealDictCursor)
cur.execute("SET SESSION statement_timeout = '100s';")
t1 = datetime.now()
try:
cur.execute(query)
except psycopg2.errors.QueryCanceled:
logger.warning("Request timed out")
return Response(status=400)
data = cur.fetchall()
cur.close()
conn.close()
t2 = datetime.now()
for d in data:
d.pop('point_geom')
return (data, t2 - t1)
def get_user(request):
token = None
if "Authorization" in request.headers:
token = request.headers["Authorization"].split(" ")[1]
if not token:
return None
idinfo = auth.verify_id_token(token)
return idinfo
def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = None
if "Authorization" in request.headers:
token = request.headers["Authorization"].split(" ")[1]
if not token:
logger.info("Request with missing authentication token.")
return {
"message": "Authentication Token is missing!",
"data": None,
"error": "Unauthorized",
}, 401
try:
idinfo = auth.verify_id_token(token)
if idinfo is None:
logger.warning(f"Invalid authentication token {token}")
return {
"message": "Invalid Authentication token!",
"data": None,
"error": "Unauthorized",
}, 403
except Exception as e:
logger.warning(f"Other error {e}")
return {
"message": "Something went wrong",
"data": None,
"error": str(e),
}, 403
logger.info(f"Authenticated request by {idinfo['email']}")
return f(*args, **kwargs)
return decorated
def make_filter_query(filter):
filter_query = sql.SQL("")
i = 0
for subfilter in filter["filters"]:
if subfilter["comparison"] not in ALLOWED_COMPARISONS:
logger.error(f"Invalid comparison {subfilter['comparison']}")
break
if subfilter["comparison"] == "=":
filter_query = sql.SQL("{filter_query} (tags @> {match})").format(
filter_query=filter_query,
match=sql.Literal(subfilter["parameter"] + "=>" + subfilter["value"]),
)
elif subfilter["comparison"] == "!=":
filter_query = sql.SQL("{filter_query} NOT(tags @> {match})").format(
filter_query=filter_query,
match=sql.Literal(subfilter["parameter"] + "=>" + subfilter["value"]),
)
elif subfilter["comparison"] == "is null":
filter_query = sql.SQL("{filter_query} NOT(tags?{parameter})").format(
filter_query=filter_query, parameter=sql.Literal(subfilter["parameter"])
)
elif subfilter["comparison"] == "is not null":
filter_query = sql.SQL("{filter_query} (tags?{parameter})").format(
filter_query=filter_query, parameter=sql.Literal(subfilter["parameter"])
)
else:
parameter = sql.SQL("tags->{parameter}").format(
parameter=sql.Literal(subfilter["parameter"])
)
if "cast" in subfilter and subfilter["cast"] in ALLOWED_CASTS:
if subfilter["cast"] == "cast_to_float":
parameter = sql.SQL("cast_to_float({parameter}, 0.0)").format(
parameter=parameter
)
elif subfilter["cast"] == "cast_to_int":
parameter = sql.SQL("cast_to_int({parameter}, 0)").format(
parameter=parameter
)
else:
parameter = sql.SQL("CAST({parameter} AS {cast})").format(
parameter=parameter, cast=sql.SQL(subfilter["cast"])
)
if subfilter["comparison"] == "starts with":
subfilter["value"] = f"{subfilter['value']}%"
subfilter["comparison"] = "ILIKE"
elif subfilter["comparison"] == "ends with":
subfilter["value"] = f"%{subfilter['value']}"
subfilter["comparison"] = "ILIKE"
elif subfilter["comparison"] == "contains":
subfilter["value"] = f"%{subfilter['value']}%"
subfilter["comparison"] = "ILIKE"
elif subfilter["comparison"] == "does not contain":
subfilter["value"] = f"%{subfilter['value']}%"
subfilter["comparison"] = "NOT ILIKE"
filter_query = sql.SQL(
"{filter_query} ({parameter} {comparison} {value})"
).format(
filter_query=filter_query,
parameter=parameter,
comparison=sql.SQL(subfilter["comparison"]),
value=sql.Literal(subfilter["value"]),
)
if i != len(filter["filters"]) - 1:
filter_query = sql.SQL("{filter_query} {method}").format(
filter_query=filter_query, method=sql.SQL(filter["method"])
)
i += 1
return filter_query
@app.route("/intersection")
@token_required
def get_intersection():
args = request.args
buffer = int(args.get("buffer"))
filters = json.loads(args.get("filters"))
l = float(args.get("l"))
b = float(args.get("b"))
r = float(args.get("r"))
t = float(args.get("t"))
bbox = [l, b, r, t]
area = (
math.pow(6371, 2)
* math.pi
* abs(math.sin(math.radians(t)) - math.sin(math.radians(b)))
* abs(r - l)
/ 180
)
# reject queries that are too large
if area > 4e6:
return Response(status=400)
bbox_filter = sql.SQL(
"AND (way && ST_Transform(ST_MakeEnvelope({left}, {bottom}, {right}, {top}, 4326), 3857))"
).format(
left=sql.Literal(bbox[0]),
bottom=sql.Literal(bbox[1]),
right=sql.Literal(bbox[2]),
top=sql.Literal(bbox[3]),
)
first = filters[0]
first_query = sql.SQL(
"SELECT tags->'name' AS name, ST_Centroid(way) AS point_geom, way AS geom FROM {table}"
).format(
table=sql.SQL("planet_osm_line")
if first["type"] == "line"
else sql.SQL("planet_osm_polygon")
if first["type"] == "polygon"
else sql.SQL("planet_osm_point")
if first["type"] == "point"
else sql.SQL("planet_osm")
)
first_filter = make_filter_query(first)
first_assembled = sql.SQL("{query} WHERE ({filter}) {bbox}").format(
query=first_query, filter=first_filter, bbox=bbox_filter
)
logger.info(f"Buffer: {buffer}\tFilters: {filters}\tBbox: [{l},{b},{r},{t}]")
subqueries = []
for f in filters[1:]:
filter = make_filter_query(f)
if f["type"] == "point":
query = sql.SQL("SELECT way AS geom FROM planet_osm_point")
elif f["type"] == "line":
query = sql.SQL("SELECT way AS geom FROM planet_osm_line")
elif f["type"] == "polygon":
query = sql.SQL("SELECT way AS geom FROM planet_osm_polygon")
elif f["type"] == "any":
query = sql.SQL("SELECT way AS geom FROM planet_osm")
assembled = sql.SQL("{query} WHERE ({filter}) {bbox}").format(
query=query, filter=filter, bbox=bbox_filter
)
subqueries.append(assembled)
join_query = sql.SQL(
"SELECT DISTINCT point_geom, name, ST_Y(ST_Transform(point_geom, 4326)) AS lat, ST_X(ST_Transform(point_geom, 4326)) as lng FROM ({point}) point "
).format(point=first_assembled)
i = 0
for q in subqueries:
join_query = sql.SQL(
"{join_query} JOIN ({q}) {subindex} ON ST_DWithin(point.geom, {subindex}.geom, {buffer})"
).format(
join_query=join_query,
q=q,
subindex=sql.SQL("subquery" + str(i)),
buffer=sql.Literal(buffer),
)
i += 1
join_query = sql.SQL("{join_query} LIMIT 100").format(join_query=join_query)
conn = get_db_connection()
logger.info(f"Executing query: {join_query.as_string(conn)}")
timestamp = time.time_ns()
user = get_user(request)
data, tdiff = query_with_timing(join_query)
log_data = {
"timestamp": firestore.SERVER_TIMESTAMP,
"query": join_query.as_string(conn),
"user_uid": user["uid"],
"user_email": user["email"],
"bbox": bbox,
"filters": filters,
"query_time": tdiff.total_seconds(),
"query_nresults": len(data),
"query_results": data,
}
db.collection("searches").document(str(timestamp)).set(log_data)
logger.info(f"Found {len(data)} results in {tdiff} seconds")
return jsonify(data)
@app.route("/robots.txt")
def robots():
return Response("User-agent: *\nDisallow: /", mimetype="text/plain")
def start():
app.run(port=5050)
if __name__ == "__main__":
start()