Browse Source

Rewrite point generator almost from scratch to use aiohttp

Kirk Trombley 4 years ago
parent
commit
35a5c78ed8

+ 6 - 0
server/app/__init__.py

@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 
 from .api import other, game
 from .db import init_db
+from .point_gen import aiohttp_client
 
 logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
 
@@ -26,3 +27,8 @@ app.add_middleware(
 @app.on_event("startup")
 def startup():
     init_db(os.environ.get("SQLALCHEMY_URL", "sqlite:////tmp/terrassumptions.db"), connect_args={"check_same_thread": False})
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    await aiohttp_client.close()

+ 9 - 4
server/app/api/game.py

@@ -1,3 +1,5 @@
+import logging
+
 from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
 from fastapi_camelcase import CamelModel
 from pydantic import conint, constr
@@ -7,7 +9,9 @@ from sqlalchemy.orm import Session
 from .. import scoring
 from ..schemas import GameConfig, Guess, RuleSetEnum
 from ..db import get_db, queries, models
-from ..point_gen import generate_points, restock_source, ExhaustedSourceError
+from ..point_gen import points, ExhaustedSourceError
+
+logger = logging.getLogger(__name__)
 
 router = APIRouter()
 
@@ -27,14 +31,15 @@ class GuessResult(CamelModel):
 
 
 @router.put("")
-def create_game(config: GameConfig, bg: BackgroundTasks, db: Session = Depends(get_db)):
+async def create_game(config: GameConfig, bg: BackgroundTasks, db: Session = Depends(get_db)):
     try:
-        coords = generate_points(config)
+        coords = await points.get_points(config)
     except ExhaustedSourceError:
+        logger.exception("")
         # TODO might be worth logging something useful here eventually
         raise HTTPException(status_code=501, detail="Sufficient points could not be generated quickly enough")
     game_id = queries.create_game(db, config, coords)
-    bg.add_task(restock_source, config)
+    bg.add_task(points.restock_source, config)
     return { "gameId": game_id }
 
 

+ 3 - 3
server/app/api/other.py

@@ -6,7 +6,7 @@ from pydantic import confloat
 
 from .. import scoring
 from ..schemas import CacheInfo, GeneratorInfo
-from ..point_gen import get_cache_info, get_generators
+from ..point_gen import points, generator_info
 
 router = APIRouter()
 
@@ -47,9 +47,9 @@ def check_score(points: ScoreCheck):
 
 @router.get("/caches", response_model=CacheResponse)
 def caches():
-    return CacheResponse(caches=get_cache_info())
+    return CacheResponse(caches=points.get_cache_info())
 
 
 @router.get("/generators", response_model=GeneratorResponse)
 def generators():
-    return GeneratorResponse(generators=get_generators())
+    return GeneratorResponse(generators=generator_info)

+ 86 - 43
server/app/point_gen/__init__.py

@@ -1,49 +1,92 @@
-from typing import List, Tuple
+import asyncio
+import collections
+import random
+from typing import List, Tuple, Dict, Union
 
-from .random_street_view import SOURCE_GROUP as RSV_SOURCE_GROUP, VALID_COUNTRIES as RSV_COUNTRIES
-from .urban_centers import SOURCE_GROUP as URBAN_CENTER_SOURCE_GROUP, VALID_COUNTRIES as URBAN_COUNTRIES
-from .shared import ExhaustedSourceError
+from .random_street_view import call_random_street_view, VALID_COUNTRIES as RSV_COUNTRIES
+from .urban_centers import urban_coord_unlocked, urban_coord_ensured, VALID_COUNTRIES as URBAN_COUNTRIES
+from .shared import ExhaustedSourceError, aiohttp_client
 
-from ..schemas import GameConfig, GenMethodEnum, CacheInfo, GeneratorInfo
+from ..schemas import GameConfig, GenMethodEnum, CountryCode, CacheInfo, GeneratorInfo
 
-source_groups = {
-    GenMethodEnum.rsv: RSV_SOURCE_GROUP,
-    GenMethodEnum.urban: URBAN_CENTER_SOURCE_GROUP,
+generator_info = [
+    GeneratorInfo(
+        generation_method=GenMethodEnum.rsv,
+        country_locks=RSV_COUNTRIES
+    ),
+    GeneratorInfo(
+        generation_method=GenMethodEnum.urban,
+        country_locks=URBAN_COUNTRIES
+    ),
+]
+
+cache_names = {
+    GenMethodEnum.rsv: "RSV",
+    GenMethodEnum.urban: "Urban",
 }
 
 
-def generate_points(config: GameConfig) -> List[Tuple[str, float, float]]:
-    """
-    Generate points according to the GameConfig.
-    """
-    # note - force exactly config.rounds points, even though most top level sources should be well-behaved in this regard
-    return source_groups[config.generation_method].get_points_from(config.rounds, config.country_lock)[:config.rounds]
-
-
-def restock_source(config: GameConfig):
-    """
-    Restock any caches associated with the GameConfig.
-    """
-    source_groups[config.generation_method].restock(config.country_lock)
-
-
-def get_cache_info() -> List[CacheInfo]:
-    """
-    Get CacheInfo for all caches.
-    """
-    return [CacheInfo(cache_name=c.get_name(), size=len(c.stock)) for g in source_groups.values() for c in g.cached]
-
-def get_generators() -> List[GeneratorInfo]:
-    """
-    Get all available Generators and their country options
-    """
-    return [
-        GeneratorInfo(
-            generation_method=GenMethodEnum.rsv,
-            country_locks=RSV_COUNTRIES
-        ),
-        GeneratorInfo(
-            generation_method=GenMethodEnum.urban,
-            country_locks=URBAN_COUNTRIES
-        ),
-    ]
+class PointStore:
+    def __init__(self, cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int]):
+        self.cache_targets = cache_targets
+        self.store = collections.defaultdict(collections.deque)
+
+    async def generate_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
+        if generator == GenMethodEnum.rsv:
+            # RSV point functions return a collection of points, which should be cached
+            point, *points = await call_random_street_view(country)
+            # use the country on the point - since country itself might be None
+            self.store[(generator, point[0])].extend(points)
+            return point
+        elif generator == GenMethodEnum.urban:
+            # urban center point functions only return a single point
+            if country is None:
+                return await urban_coord_unlocked()
+            return await urban_coord_ensured(country, city_retries=50)
+        else:
+            raise ExhaustedSourceError()
+
+    async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
+        if country is not None:
+            # if we already have a point ready, just return it immediately
+            # to avoid bias, we only do this in country-locking mode
+            stock = self.store[(generator, country)]
+            if len(stock) > 0:
+                return stock.popleft()
+
+        return await self.generate_point(generator, country)
+
+    async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
+        """
+        Provide points according to the GameConfig.
+
+        Return a list of at least n valid geo points, as 
+        (2 character country code, latitude, longitude) tuples.
+
+        In the event that the configured source cannot reasonably supply enough points,
+        most likely due to time constraints, this will raise an ExhaustedSourceError.
+        """
+        return await asyncio.gather(*[self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)])
+
+    def get_cache_info(self) -> List[CacheInfo]:
+        """
+        Get CacheInfo for all caches.
+        """
+        return [CacheInfo(cache_name=f"{cache_names[g]}-{c}", size=len(ps)) for (g, c), ps in self.store.items()]
+
+    async def restock_source(self, config: GameConfig):
+        """
+        Restock any caches associated with the GameConfig.
+        """
+        if config.country_lock is None:
+            return
+        key = (config.generation_method, config.country_lock)
+        target = self.cache_targets.get(key, 0)
+        stock = self.store[key]
+        while len(stock) < target:
+            stock.append(await self.generate_point(*key))
+
+
+points = PointStore({
+    (GenMethodEnum.urban, "us"): 10,
+})

+ 33 - 61
server/app/point_gen/random_street_view.py

@@ -1,8 +1,7 @@
 import random
+import logging
 
-import requests
-
-from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, ExhaustedSourceError, GeoPointSourceGroup
+from .shared import aiohttp_client, point_has_streetview, ExhaustedSourceError
 
 RSV_URL = "https://randomstreetview.com/data"
 VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
@@ -14,69 +13,42 @@ VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw",
                    "sg", "sk", "si", "za", "kr", "es", "sz", 
                    "se", "ch", "tw", "th", "ua", "gb", "us")
 
+logger = logging.getLogger(__name__)
+
 
-def call_random_street_view(country_lock):
+async def call_random_street_view(country_lock, max_attempts=5):
     """
     Returns an array of (some number of) tuples, each being (latitude, longitude).
-    All points will be valid streetview coordinates. There is no guarantee as to the
-    length of this array (it may be empty), but it will never be None.
+    All points will be valid streetview coordinates, in the country indicated by
+    country_lock. If the country_lock provided is None, a valid one is chosen at 
+    random. The returned array will never be empty. If after max_attempts no points 
+    can be found (which is very rare), this raises an ExhaustedSourceError.
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
-    try:
-        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
-    except:
-        return []
 
-    if not rsv_js["success"]:
-        return []
+    if country_lock is None:
+        country_lock = random.choice(VALID_COUNTRIES)
     
-    return [
-        (country_lock, point["lat"], point["lng"])
-        for point in rsv_js["locations"]
-        if point_has_streetview(point["lat"], point["lng"])
-    ]
-
-
-class RSVCountryPointSource(GeoPointSource):
-    def __init__(self, country_lock, max_attempts=100):
-        self.country_lock = country_lock
-        self.max_attempts = max_attempts
-
-    def get_name(self):
-        return f"RSV-{self.country_lock}"
-
-    def get_points(self, n):
-        attempts = 0
-        points = []
-        while len(points) < n:
-            if attempts > self.max_attempts:
-                raise ExhaustedSourceError(points)
-            points.extend(call_random_street_view(country_lock=self.country_lock))
-            attempts += 1
-        return points
-
-
-COUNTRY_SOURCES = { k: CachedGeoPointSource(RSVCountryPointSource(k), 10) for k in VALID_COUNTRIES }
-
-
-class RSVWorldPointSource(GeoPointSource):
-    def __init__(self, max_attempts=10):
-        self.max_attempts = max_attempts
-
-    def get_name(self):
-        return "RSV-global"
-
-    def get_points(self, n):
-        attempts = 0
-        points = []
-        while len(points) < n:
-            if attempts > self.max_attempts:
-                raise ExhaustedSourceError(points)
-            points.extend(COUNTRY_SOURCES[random.choice(VALID_COUNTRIES)].get_points(1))
-            attempts += 1
-        return points
-
-
-WORLD_SOURCE = RSVWorldPointSource()
-SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)
+    for _ in range(max_attempts):
+        logger.info("Attempting RSV...")
+        try:
+            async with aiohttp_client.post(RSV_URL, data={"country": country_lock.lower()}) as response:
+                rsv_js = await response.json(content_type=None)
+                logger.info(f"Got back {rsv_js.keys()}")
+        except:
+            logger.exception("Failed RSV")
+            continue
+
+        if not rsv_js["success"]:
+            continue
+        
+        points = [
+            (country_lock, point["lat"], point["lng"])
+            for point in rsv_js["locations"]
+            if await point_has_streetview(point["lat"], point["lng"])
+        ]
+        if len(points) > 0:
+            return points
+    else:
+        raise ExhaustedSourceError()

+ 11 - 109
server/app/point_gen/shared.py

@@ -2,7 +2,9 @@ from typing import List, Tuple, Union, Dict
 import collections
 import logging
 
-import requests
+import aiohttp
+
+from ..schemas import GenMethodEnum, CountryCode
 
 # Google API key, with access to Street View Static API
 # this can be safely committed due to permission restriction
@@ -11,125 +13,25 @@ metadata_url = "https://maps.googleapis.com/maps/api/streetview/metadata"
 
 logger = logging.getLogger(__name__)
 
+aiohttp_client = aiohttp.ClientSession()
+
 
-def point_has_streetview(lat, lng):
+async def point_has_streetview(lat, lng):
     """
     Returns True if the streetview metadata endpoint says a given point has
     data available, and False otherwise.
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
-    return requests.get(metadata_url, params={
+    params = {
         "key": google_api_key,
         "location": f"{lat},{lng}",
-    }).json()["status"] == "OK"
+    }
+    async with aiohttp_client.get(metadata_url, params=params) as response:
+        body = await response.json()
+        return body["status"] == "OK"
 
 
 class ExhaustedSourceError(Exception):
     def __init__(self, partial=[]):
         self.partial = partial
-
-
-class GeoPointSource:
-    """
-    Abstract base class for a source of geo points
-    """
-    def get_name(self) -> str:
-        """
-        Return a human-readable name for this point source, for debugging purposes.
-        """
-        raise NotImplemented("Must be implemented by subclasses")
-
-    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
-        """
-        Return a list of at least n valid geo points, as 
-        (2 character country code, latitude, longitude) tuples.
-        In the event that the GeoPointSource cannot reasonably supply enough points,
-        most likely due to time constraints, it should raise an ExhaustedSourceError.
-        """
-        raise NotImplemented("Must be implemented by subclasses")
-
-
-class CachedGeoPointSource(GeoPointSource):
-    """
-    Wrapper tool for maintaing a cache of points from a GeoPointSource to
-    make get_points faster, at the exchange of needing to restock those
-    points after the fact. This can be done in another thread, however, to
-    hide this cost from the user.
-    """
-
-    def __init__(self, source: GeoPointSource, stock_target: int):
-        self.source = source
-        self.stock = collections.deque()
-        self.stock_target = stock_target
-
-    def get_name(self):
-        return f"Cached({self.source.get_name()}, {self.stock_target})"
-
-    def restock(self, n: Union[int, None] = None):
-        """
-        Restock at least n points into this source.
-        If n is not provided, it will default to stock_target, as set during the
-        construction of this point source.
-        """
-        n = n if n is not None else self.stock_target - len(self.stock)
-        if n > 0:
-            logger.info(f"Restocking {self.get_name()} with {n} points")
-            try:
-                pts = self.source.get_points(n)
-            except ExhaustedSourceError as e:
-                pts = e.partial  # take what we can get
-            self.stock.extend(pts)
-            diff = n - len(pts)
-            if diff > 0:
-                # if implementations of source.get_points are well behaved, this will
-                # never actually need to recurse to finish the job.
-                self.restock(n=diff)
-            logger.info(f"Finished restocking {self.get_name()}")
-
-    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
-        """
-        Pull n points from the current stock.
-        It is recommended to call CachedGeoPointSource.restock after this, to ensure 
-        the stock is not depleted. If possible, calling restock in another thread is
-        recommended, as it can be a long operation depending on implementation.
-        """
-        if len(self.stock) >= n:
-            pts = []
-            for _ in range(n):
-                pts.append(self.stock.popleft())
-            return pts
-        self.restock(n=n)
-        # this is safe as long as restock does actually add enough new points.
-        # unless this object is being rapidly drained by another thread,
-        # this will recur at most once.
-        return self.get_points(n=n)
-
-
-class GeoPointSourceGroup:
-    """
-    Container of multiple GeoPointSources, each with some key.
-    """
-    def __init__(self, sources: Dict[str, GeoPointSource], default: GeoPointSource):
-        self.sources = sources
-        self.default = default
-        self.cached = [s for s in sources.values() if isinstance(s, CachedGeoPointSource)]
-        if isinstance(default, CachedGeoPointSource):
-            self.cached.append(default)
-
-    def restock(self, key: Union[str, None] = None):
-        """
-        Restock a CachedGeoPointSources managed by this group.
-        If the targeted GeoPointSource is uncached, this method does nothing.
-        """
-        src = self.sources.get(key, self.default)
-        if isinstance(src, CachedGeoPointSource):
-            src.restock()
-
-    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[str, float, float]]:
-        """
-        Return a list of at least n valid geo points, for a given key. If no key 
-        is provided, or no matching GeoPointSource is found, the default 
-        GeoPointSource will be used.
-        """
-        return self.sources.get(key, self.default).get_points(n)

+ 40 - 67
server/app/point_gen/urban_centers.py

@@ -2,15 +2,15 @@ import math
 import random
 import csv
 import logging
+import asyncio
 from collections import defaultdict
 
-from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, GeoPointSourceGroup
+from .shared import point_has_streetview, ExhaustedSourceError
 from ..scoring import mean_earth_radius_km
 
 logger = logging.getLogger(__name__)
-URBAN_CENTERS = defaultdict(list)
-
 
+URBAN_CENTERS = defaultdict(list)
 _found_countries = set()
 _urban_center_count = 0
 with open("./data/urban-centers.csv") as infile:
@@ -22,7 +22,7 @@ logger.info(f"Read {_urban_center_count} urban centers from {len(_found_countrie
 VALID_COUNTRIES = tuple(_found_countries)
 
 
-def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25):
+async def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25):
     """
     Returns (latitude, longitude) of usable coord (where google has data) that is near
     a known urban center. Points will be at most max_dist_km kilometers away. This function 
@@ -60,70 +60,43 @@ def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25)
             pt_lng_rad = city_lng_rad + math.atan2(math.sin(angle_rad) * sin_dor * cos_lat, cos_dor - sin_lat * math.sin(pt_lat_rad))
             pt_lat = math.degrees(pt_lat_rad)
             pt_lng = math.degrees(pt_lng_rad)
-            if point_has_streetview(pt_lat, pt_lng):
+            if await point_has_streetview(pt_lat, pt_lng):
                 logger.info("Point found!")
                 return (country_lock, pt_lat, pt_lng)
 
 
-class WorldUrbanPointSource(GeoPointSource):
-    def __init__(self, country_retries=30):
-        self.country_retries = country_retries
-    
-    def get_name(self):
-        return "Urban-global"
-
-    def get_points(self, n):
-        # Will make at most self.country_retries * n attempts to call urban_coord
-        points = []
-        for _ in range(n):
-            countries = random.sample(URBAN_CENTERS.keys(), k=min(self.country_retries, len(URBAN_CENTERS)))
-            for country in countries:
-                logger.info(f"Selecting urban centers from {c}")
-                pt = urban_coord(c)
-                if pt is not None:
-                    points.append(pt)
-                    break
-            else:
-                raise ExhaustedSourceError(points)
-        return points
-
-
-class CountryUrbanPointSource(GeoPointSource):
-    def __init__(self, country_lock, max_attempts=5):
-        self.country_lock = country_lock
-        self.max_attempts = max_attempts
-
-    def get_name(self):
-        return f"Urban-{self.country_lock}"
-
-    def get_points(self, n):
-        # Will make at most self.max_attempts * n calls to urban_coord with 100 city retries each
-        attempts = 0
-        points = []
-        for _ in range(n):
-            for _ in range(self.max_attempts):
-                pt = urban_coord(
-                    city_retries=100,
-                    country_lock=self.country_lock,
-                )
-                if pt is not None:
-                    points.append(pt)
-                    break
-            else:
-                raise ExhaustedSourceError(points)
-        return points
-
-
-class CountryUrbanSourceDict(dict):
-    def get(self, key, default):
-        if key is None:
-            return default
-        if key not in self:
-            self[key] = CountryUrbanPointSource(key)
-        return self[key]
-
-
-WORLD_SOURCE = CachedGeoPointSource(WorldUrbanPointSource(), 20)
-COUNTRY_SOURCES = CountryUrbanSourceDict()
-COUNTRY_SOURCES["us"] = CachedGeoPointSource(CountryUrbanPointSource("us"), 20) # cache US
-SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)
+async def urban_coord_unlocked(country_retries=30, city_retries=10, point_retries=10, max_dist_km=25):
+    """
+    The same behavior as urban_coord, but for a randomly chosen country. Will attempt at most
+    country_retries countries, calling urban_coord for each, with the provided settings.
+
+    Will never return None, instead opting to raise ExhaustedSourceError on failure.
+
+    This function calls the streetview metadata endpoint - there is no quota consumed.
+    """
+    countries = random.sample(URBAN_CENTERS.keys(), k=min(country_retries, len(URBAN_CENTERS)))
+    for country in countries:
+        logger.info(f"Selecting urban centers from {country}")
+        pt = await urban_coord(country, city_retries=city_retries, point_retries=point_retries, max_dist_km=max_dist_km)
+        if pt is not None:
+            return pt
+    else:
+        raise ExhaustedSourceError()
+
+
+async def urban_coord_ensured(country_lock, max_attempts=30, city_retries=10, point_retries=10, max_dist_km=25):
+    """
+    The same behavior as urban_coord, but will make at most max_attempts cycles through the
+    behavior of urban_coord, trying to ensure a valid point is found.
+
+    Will never return None, instead opting to raise ExhaustedSourceError on failure.
+
+    This function calls the streetview metadata endpoint - there is no quota consumed.
+    """
+    for i in range(max_attempts):
+        logger.info(f"Attempt #{i + 1} to select urban centers from {country_lock}")
+        pt = await urban_coord(country_lock, city_retries=city_retries, point_retries=point_retries, max_dist_km=max_dist_km)
+        if pt is not None:
+            return pt
+    else:
+        raise ExhaustedSourceError(points)

+ 5 - 1
server/requirements.txt

@@ -1,3 +1,6 @@
+aiohttp==3.7.4.post0
+async-timeout==3.0.1
+attrs==20.3.0
 certifi==2020.12.5
 chardet==4.0.0
 click==7.1.2
@@ -8,9 +11,9 @@ h11==0.12.0
 haversine==2.3.0
 httptools==0.1.1
 idna==2.10
+multidict==5.1.0
 pydantic==1.8.1
 pyhumps==1.6.1
-requests==2.25.1
 SQLAlchemy==1.4.7
 starlette==0.13.6
 typing-extensions==3.7.4.3
@@ -18,3 +21,4 @@ urllib3==1.26.4
 uvicorn==0.13.4
 uvloop==0.15.2
 websockets==8.1
+yarl==1.6.3