Эх сурвалжийг харах

Rewrite of point gen logic for speed and caching

Kirk Trombley 4 жил өмнө
parent
commit
c355004ba1

+ 8 - 1
server/app/__init__.py

@@ -1,12 +1,13 @@
 import os
 import logging
+import asyncio
 
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 from .api import other, game
 from .db import init_db
-from .point_gen import aiohttp_client
+from .point_gen import aiohttp_client, points
 
 logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
 
@@ -24,11 +25,17 @@ app.add_middleware(
 )
 
 
+restocking_task = None
+
+
 @app.on_event("startup")
 def startup():
+    # rebind the module-level name: a plain assignment here would create a local
+    # and leave the handle at None, so shutdown could never cancel the task
+    global restocking_task
     init_db(os.environ.get("SQLALCHEMY_URL", "sqlite:////tmp/terrassumptions.db"), connect_args={"check_same_thread": False})
+    restocking_task = asyncio.create_task(points.restock_all())
 
 
 @app.on_event("shutdown")
 async def shutdown_event():
+    # stop the background restock first: it issues its requests through
+    # aiohttp_client, so it should not outlive the session closed below
+    if restocking_task is not None and not restocking_task.done():
+        restocking_task.cancel()
     await aiohttp_client.close()

+ 103 - 32
server/app/point_gen/__init__.py

@@ -1,14 +1,17 @@
 import asyncio
 import collections
+import logging
 import random
 from typing import List, Tuple, Dict, Union
 
 from .random_street_view import call_random_street_view, VALID_COUNTRIES as RSV_COUNTRIES
-from .urban_centers import urban_coord_unlocked, urban_coord_ensured, VALID_COUNTRIES as URBAN_COUNTRIES
-from .shared import ExhaustedSourceError, aiohttp_client
+from .urban_centers import urban_coord, VALID_COUNTRIES as URBAN_COUNTRIES
+from .shared import aiohttp_client
 
 from ..schemas import GameConfig, GenMethodEnum, CountryCode, CacheInfo, GeneratorInfo
 
+logger = logging.getLogger(__name__)
+
 generator_info = [
     GeneratorInfo(
         generation_method=GenMethodEnum.rsv,
@@ -26,47 +29,96 @@ cache_names = {
 }
 
 
+class ExhaustedSourceError(Exception):
+    """Raised when a point source cannot supply the requested points."""
+
+
 class PointStore:
-    def __init__(self, cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int]):
+    def __init__(self, 
+                 cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int],
+                 rsv_country_retries: int = 5,
+                 urban_country_pool_size: int = 30,
+                 urban_country_retries: int = 30,
+                 urban_city_retries: int = 50):
+        """
+        Create an empty point store.
+
+        cache_targets maps (generator, country) to the number of points the
+        restock methods try to keep cached for that pair.
+        rsv_country_retries is how many random countries get_point tries for RSV
+        when no country is given.
+        urban_country_pool_size caps how many countries get_point samples for
+        urban points when no country is given.
+        urban_country_retries is how many times get_point calls urban_coord for a
+        given country, and urban_city_retries is passed to each of those calls.
+        """
         self.cache_targets = cache_targets
+        self.rsv_country_retries = rsv_country_retries
+        self.urban_country_pool_size = urban_country_pool_size
+        self.urban_country_retries = urban_country_retries
+        self.urban_city_retries = urban_city_retries
+        # one deque of cached points per (generator, country) key, created on first access
         self.store = collections.defaultdict(collections.deque)
 
-    async def generate_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
-        if generator == GenMethodEnum.rsv:
-            # RSV point functions return a collection of points, which should be cached
-            point, *points = await call_random_street_view(country)
-            # use the country on the point - since country itself might be None
-            self.store[(generator, point[0])].extend(points)
-            return point
-        elif generator == GenMethodEnum.urban:
-            # urban center point functions only return a single point
-            if country is None:
-                return await urban_coord_unlocked()
-            return await urban_coord_ensured(country, city_retries=50)
+    async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None], force_generate: bool = False) -> Tuple[str, float, float]:
+        """
+        Provide a single point from the given generator.
+
+        If country is None, countries are picked at random until one yields a
+        point, and the cache is never read. Otherwise a cached point for
+        (generator, country) is returned when one is available, unless
+        force_generate is set, in which case a new point is always generated.
+
+        Surplus points returned by an RSV call are added to the cache for that country.
+        Raises ExhaustedSourceError if no point could be produced.
+        """
+        if country is None:
+            # generating points across the whole world
+            # for current generators, this means selecting a country at random
+            if generator == GenMethodEnum.rsv:
+                for _ in range(self.rsv_country_retries):
+                    # try a few countries before giving up, just in case one has no data
+                    country = random.choice(RSV_COUNTRIES)
+                    # RSV point function returns a collection of points, which should be cached
+                    points = await call_random_street_view(country)
+                    if len(points) > 0:
+                        point = points.pop()
+                        self.store[(generator, country)].extend(points)
+                        return point
+                else:
+                    raise ExhaustedSourceError
+            elif generator == GenMethodEnum.urban:
+                # try many countries since finding an urban center point is harder
+                countries = random.sample(URBAN_COUNTRIES, k=min(self.urban_country_pool_size, len(URBAN_COUNTRIES)))
+                for country in countries:
+                    logger.info(f"Selecting urban centers from {country}")
+                    pt = await urban_coord(country)
+                    if pt is not None:
+                        return pt
+                else:
+                    raise ExhaustedSourceError
+            else:
+                raise ExhaustedSourceError
         else:
-            raise ExhaustedSourceError
-
-    async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
-        if country is not None:
+            # generating points for a specific country
             # if we already have a point ready, just return it immediately
-            # to avoid bias, we only do this in country-locking mode
-            stock = self.store[(generator, country)]
-            if len(stock) > 0:
-                return stock.popleft()
-
-        return await self.generate_point(generator, country)
+            if not force_generate:
+                stock = self.store[(generator, country)]
+                if len(stock) > 0:
+                    return stock.popleft()
+            
+            # otherwise, need to actually generate a new point
+            if generator == GenMethodEnum.rsv:
+                # RSV point function returns a collection of points, which should be cached
+                points = await call_random_street_view(country)
+                if len(points) == 0:
+                    raise ExhaustedSourceError
+                point = points.pop()
+                self.store[(generator, country)].extend(points)
+                return point
+            elif generator == GenMethodEnum.urban:
+                for i in range(self.urban_country_retries):
+                    logger.info(f"Attempt #{i + 1} to select urban centers from {country}")
+                    pt = await urban_coord(country, city_retries=self.urban_city_retries)
+                    if pt is not None:
+                        return pt
+                else:
+                    raise ExhaustedSourceError
+            else:
+                raise ExhaustedSourceError
 
     async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
         """
         Provide points according to the GameConfig.
 
-        Return a list of at least n valid geo points, as 
+        Return a list of valid geo points, as 
         (2 character country code, latitude, longitude) tuples.
 
         In the event that the configured source cannot reasonably supply enough points,
         most likely due to time constraints, this will raise an ExhaustedSourceError.
         """
-        return await asyncio.gather(*[self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)])
+        try:
+            # one get_point coroutine per round, all awaited together by gather
+            point_tasks = [self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)]
+            gathered = asyncio.gather(*point_tasks)
+            # the whole batch shares a single 60 second budget
+            return await asyncio.wait_for(gathered, 60)
+        except asyncio.TimeoutError:
+            # a timeout is reported to callers as an exhausted source
+            raise ExhaustedSourceError
 
     def get_cache_info(self) -> List[CacheInfo]:
         """
@@ -74,17 +126,36 @@ class PointStore:
         """
         return [CacheInfo(cache_name=f"{cache_names[g]}-{c}", size=len(ps)) for (g, c), ps in self.store.items()]
 
+    async def _restock_source_impl(self, generator: GenMethodEnum, country: CountryCode):
+        """
+        Generate points for (generator, country) until its cache reaches the
+        configured target. Pairs without a target are treated as target 0.
+        """
+        cache_key = (generator, country)
+        wanted = self.cache_targets.get(cache_key, 0)
+        cached = self.store[cache_key]
+        # re-check the length on every pass: a single RSV call may add several points
+        while len(cached) < wanted:
+            fresh = await self.get_point(generator, country, force_generate=True)
+            cached.append(fresh)
+
     async def restock_source(self, config: GameConfig):
         """
         Restock any caches associated with the GameConfig.
         """
         if config.country_lock is None:
+            # get_point never reads the cache when no country is given, so there is nothing to restock
             return
-        key = (config.generation_method, config.country_lock)
-        target = self.cache_targets.get(key, 0)
-        stock = self.store[key]
-        while len(stock) < target:
-            stock.append(await self.generate_point(*key))
+        try:
+            await self._restock_source_impl(config.generation_method, config.country_lock)
+        except ExhaustedSourceError:
+            # if the cache can't be restocked, that is bad, but not fatal
+            logger.exception(f"Failed to fully restock point cache for {config}")
+
+    async def restock_all(self, timeout: Union[int, float, None] = None):
+        """
+        Restock all caches.
+
+        Runs a restock for every (generator, country) key in cache_targets,
+        bounded by timeout seconds (None means no limit). This is best effort:
+        a timeout or an exhausted source is logged rather than raised.
+        """
+        restock_tasks = [self._restock_source_impl(gen, cc) for (gen, cc) in self.cache_targets.keys()]
+        gathered = asyncio.gather(*restock_tasks)
+        try:
+            await asyncio.wait_for(gathered, timeout)
+        except (asyncio.TimeoutError, ExhaustedSourceError):
+            # if this task times out, it's fine, as it's just intended to be a best effort
+            # there is no GameConfig in this scope, so the message must not reference one
+            logger.exception("Failed to fully restock all point caches")
 
 
 points = PointStore({

+ 21 - 30
server/app/point_gen/random_street_view.py

@@ -1,7 +1,8 @@
 import random
 import logging
+import asyncio
 
-from .shared import aiohttp_client, point_has_streetview, ExhaustedSourceError
+from .shared import aiohttp_client, point_has_streetview
 
 RSV_URL = "https://randomstreetview.com/data"
 VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
@@ -16,39 +17,29 @@ VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw",
 logger = logging.getLogger(__name__)
 
 
-async def call_random_street_view(country_lock, max_attempts=5):
+async def call_random_street_view(country_lock):
     """
-    Returns an array of (some number of) tuples, each being (latitude, longitude).
+    Returns an array of (some number of) tuples, each being (country_lock, latitude, longitude).
     All points will be valid streetview coordinates, in the country indicated by
-    country_lock. If the country_lock provided is None, a valid one is chosen at 
-    random. The returned array will never be empty. If after max_attempts no points 
-    can be found (which is very rare), this raises an ExhaustedSourceError.
+    country_lock. No size guarantee is given on the returned array - it could be empty.
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
-
-    if country_lock is None:
-        country_lock = random.choice(VALID_COUNTRIES)
     
-    for _ in range(max_attempts):
-        logger.info("Attempting RSV...")
-        try:
-            async with aiohttp_client.post(RSV_URL, data={"country": country_lock.lower()}) as response:
-                rsv_js = await response.json(content_type=None)
-                logger.info(f"Got back {rsv_js.keys()}")
-        except:
-            logger.exception("Failed RSV")
-            continue
-
-        if not rsv_js["success"]:
-            continue
-        
-        points = [
-            (country_lock, point["lat"], point["lng"])
-            for point in rsv_js["locations"]
-            if await point_has_streetview(point["lat"], point["lng"])
-        ]
-        if len(points) > 0:
-            return points
-    else:
-        raise ExhaustedSourceError
+    try:
+        async with aiohttp_client.post(RSV_URL, data={"country": country_lock.lower()}) as response:
+            rsv_js = await response.json(content_type=None)
+    except asyncio.CancelledError:
+        # cancellation (e.g. from an asyncio.wait_for timeout in a caller) must
+        # propagate; turning it into an empty result would make callers retry
+        raise
+    except Exception:
+        # any other failure of the RSV request is treated as "no points"
+        logger.exception("Failed RSV call")
+        return []
+
+    if not rsv_js["success"]:
+        return []
+
+    points = []
+    async def add_point_if_valid(point):
+        if await point_has_streetview(point["lat"], point["lng"]):
+            points.append((country_lock, point["lat"], point["lng"]))
+    
+    # check all candidate locations concurrently; result order follows completion order
+    await asyncio.gather(*[add_point_if_valid(p) for p in rsv_js["locations"]])
+    return points

+ 0 - 4
server/app/point_gen/shared.py

@@ -30,7 +30,3 @@ async def point_has_streetview(lat, lng):
     async with aiohttp_client.get(metadata_url, params=params) as response:
         body = await response.json()
         return body["status"] == "OK"
-
-
-class ExhaustedSourceError(Exception):
-    pass

+ 1 - 39
server/app/point_gen/urban_centers.py

@@ -2,10 +2,9 @@ import math
 import random
 import csv
 import logging
-import asyncio
 from collections import defaultdict
 
-from .shared import point_has_streetview, ExhaustedSourceError
+from .shared import point_has_streetview
 from ..scoring import mean_earth_radius_km
 
 logger = logging.getLogger(__name__)
@@ -63,40 +62,3 @@ async def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_
             if await point_has_streetview(pt_lat, pt_lng):
                 logger.info("Point found!")
                 return (country_lock, pt_lat, pt_lng)
-
-
-async def urban_coord_unlocked(country_retries=30, city_retries=10, point_retries=10, max_dist_km=25):
-    """
-    The same behavior as urban_coord, but for a randomly chosen country. Will attempt at most
-    country_retries countries, calling urban_coord for each, with the provided settings.
-
-    Will never return None, instead opting to raise ExhaustedSourceError on failure.
-
-    This function calls the streetview metadata endpoint - there is no quota consumed.
-    """
-    countries = random.sample(URBAN_CENTERS.keys(), k=min(country_retries, len(URBAN_CENTERS)))
-    for country in countries:
-        logger.info(f"Selecting urban centers from {country}")
-        pt = await urban_coord(country, city_retries=city_retries, point_retries=point_retries, max_dist_km=max_dist_km)
-        if pt is not None:
-            return pt
-    else:
-        raise ExhaustedSourceError
-
-
-async def urban_coord_ensured(country_lock, max_attempts=30, city_retries=10, point_retries=10, max_dist_km=25):
-    """
-    The same behavior as urban_coord, but will make at most max_attempts cycles through the
-    behavior of urban_coord, trying to ensure a valid point is found.
-
-    Will never return None, instead opting to raise ExhaustedSourceError on failure.
-
-    This function calls the streetview metadata endpoint - there is no quota consumed.
-    """
-    for i in range(max_attempts):
-        logger.info(f"Attempt #{i + 1} to select urban centers from {country_lock}")
-        pt = await urban_coord(country_lock, city_retries=city_retries, point_retries=point_retries, max_dist_km=max_dist_km)
-        if pt is not None:
-            return pt
-    else:
-        raise ExhaustedSourceError