Просмотр исходного кода

Add country lock, disable urban centers

Kirk Trombley 4 лет назад
Родитель
Сommit
bef6169328

+ 7 - 3
server/app/api/game.py

@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
 from .. import scoring
 from ..schemas import GameConfig, Guess
 from ..db import get_db, queries, models
-from ..gen import generate_points, restock_source
+from ..point_gen import generate_points, restock_source, ExhaustedSourceError
 
 router = APIRouter()
 
@@ -28,9 +28,13 @@ class GuessResult(CamelModel):
 
 @router.put("")
 def create_game(config: GameConfig, bg: BackgroundTasks, db: Session = Depends(get_db)):
-    coords = generate_points(config)
+    try:
+        coords = generate_points(config)
+    except ExhaustedSourceError:
+        # TODO might be worth logging something useful here eventually
+        raise HTTPException(status_code=501, detail="Sufficient points could not be generated quickly enough")
     game_id = queries.create_game(db, config, coords)
-    bg.add_task(restock_source, config)
+    bg.add_task(restock_source, config.generation_method)
     return { "gameId": game_id }
 
 

+ 2 - 2
server/app/api/other.py

@@ -6,7 +6,7 @@ from pydantic import confloat
 
 from .. import scoring
 from ..schemas import CacheInfo
-from ..gen import get_cache_info
+from ..point_gen import get_cache_info
 
 router = APIRouter()
 
@@ -32,7 +32,7 @@ class CacheResponse(CamelModel):
 
 @router.get("/health")
 def health():
-    return { "status": "healthy", "version": "2.0" }
+    return { "status": "healthy", "version": "3.0" }
 
 
 @router.post("/score", response_model=Score)

+ 1 - 1
server/app/db/models.py

@@ -13,7 +13,7 @@ class Game(Base):
     linked_game = Column(String)
     timer = Column(Integer)
     rounds = Column(Integer)
-    only_america = Column(Boolean)
+    country_lock = Column(String)
     generation_method = Column(String)
     rule_set = Column(String)
     coordinates = relationship("Coordinate", lazy=True, order_by="Coordinate.round_number")

+ 1 - 1
server/app/db/queries.py

@@ -20,7 +20,7 @@ def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[float,
         game_id=game_id,
         timer=conf.timer,
         rounds=conf.rounds,
-        only_america=conf.only_america,
+        country_lock=conf.country_lock,
         generation_method=conf.generation_method,
         rule_set=conf.rule_set,
     )

+ 31 - 0
server/app/point_gen/__init__.py

@@ -0,0 +1,31 @@
+from typing import List, Tuple
+
+from .random_street_view import SOURCE_GROUP as RSV_SOURCE_GROUP
+from .shared import ExhaustedSourceError
+
+from ..schemas import GameConfig, GenMethodEnum, CacheInfo
+
+source_groups = {
+    GenMethodEnum.rsv: RSV_SOURCE_GROUP,
+}
+
+
+def generate_points(config: GameConfig) -> List[Tuple[float, float]]:
+    """
+    Generate points according to the GameConfig.
+    """
+    return source_groups[config.generation_method].get_points_from(config.rounds, config.country_lock)
+
+
+def restock_source(generation_method: GenMethodEnum):
+    """
+    Restock any caches associated with the generation method.
+    """
+    source_groups[generation_method].restock_all()
+
+
+def get_cache_info() -> List[CacheInfo]:
+    """
+    Get CacheInfo for all caches
+    """
+    return [CacheInfo(cache_name=c.get_name(), size=len(c.stock)) for g in source_groups.values() for c in g.cached]

+ 63 - 0
server/app/point_gen/random_street_view.py

@@ -0,0 +1,63 @@
+import requests
+
+from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, ExhaustedSourceError, GeoPointSourceGroup
+
+RSV_URL = "https://randomstreetview.com/data"
+
+
+def call_random_street_view(country_lock=None):
+    """
+    Returns an array of (some number of) tuples, each being (latitude, longitude).
+    All points will be valid streetview coordinates. There is no guarantee as to the
+    length of this array (it may be empty), but it will never be None.
+
+    This function calls the streetview metadata endpoint - there is no quota consumed.
+    """
+    try:
+        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower() if country_lock is not None else "all"}).json()
+    except:
+        return []
+
+    if not rsv_js["success"]:
+        return []
+    
+    return [
+        (point["lat"], point["lng"])
+        for point in rsv_js["locations"]
+        if point_has_streetview(point["lat"], point["lng"])
+    ]
+
+
+class RSVPointSource(GeoPointSource):
+    def __init__(self, country_lock=None, max_attempts=100):
+        self.country_lock = country_lock
+        self.max_attempts = max_attempts
+
+    def get_name(self):
+        return f"RSV-{self.country_lock or 'all'}"
+
+    def get_points(self, n):
+        attempts = 0
+        points = []
+        while len(points) < n:
+            if attempts > self.max_attempts:
+                raise ExhaustedSourceError()
+            points.extend(call_random_street_view(country_lock=self.country_lock))
+            attempts += 1
+        return points
+
+
+WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
+VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
+                   "br", "bg", "kh", "ca", "cl", "hr", "co", 
+                   "cz", "dk", "ae", "ee", "fi", "fr", "de", 
+                   "gr", "hu", "hk", "is", "id", "ie", "it", 
+                   "il", "jp", "lv", "lt", "my", "mx", "nl", 
+                   "nz", "no", "pe", "pl", "pt", "ro", "ru", 
+                   "sg", "sk", "si", "za", "kr", "es", "sz", 
+                   "se", "ch", "tw", "th", "ua", "gb", "us")
+COUNTRY_SOURCES = {
+    "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
+    **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }
+}
+SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)

+ 128 - 0
server/app/point_gen/shared.py

@@ -0,0 +1,128 @@
+from typing import List, Tuple, Union, Dict
+import collections
+import logging
+
+import requests
+
+# Google API key, with access to Street View Static API
+# this can be safely committed due to permission restriction
+google_api_key = "AIzaSyAqjCYR6Szph0X0H_iD6O1HenFhL9jySOo"
+metadata_url = "https://maps.googleapis.com/maps/api/streetview/metadata"
+
+logger = logging.getLogger(__name__)
+
+
+def point_has_streetview(lat, lng):
+    """
+    Returns True if the streetview metadata endpoint says a given point has
+    data available, and False otherwise.
+
+    This function calls the streetview metadata endpoint - there is no quota consumed.
+    """
+    return requests.get(metadata_url, params={
+        "key": google_api_key,
+        "location": f"{lat},{lng}",
+    }).json()["status"] == "OK"
+
+
+class ExhaustedSourceError(Exception):
+    pass
+
+
+class GeoPointSource:
+    """
+    Abstract base class for a source of geo points
+    """
+    def get_name(self) -> str:
+        """
+        Return a human-readable name for this point source, for debugging purposes.
+        """
+        raise NotImplemented("Must be implemented by subclasses")
+
+    def get_points(self, n: int) -> List[Tuple[float, float]]:
+        """
+        Return a list of at least n valid geo points, as (latitude, longitude) pairs.
+        In the event that the GeoPointSource cannot reasonably supply enough points,
+        most likely due to time constraints, it should raise an ExhaustedSourceError.
+        """
+        raise NotImplemented("Must be implemented by subclasses")
+
+
+class CachedGeoPointSource(GeoPointSource):
+    """
+    Wrapper tool for maintaing a cache of points from a GeoPointSource to
+    make get_points faster, at the exchange of needing to restock those
+    points after the fact. This can be done in another thread, however, to
+    hide this cost from the user.
+    """
+
+    def __init__(self, source: GeoPointSource, stock_target: int):
+        self.source = source
+        self.stock = collections.deque()
+        self.stock_target = stock_target
+
+    def get_name(self):
+        return f"Cached({self.source.get_name()}, {self.stock_target})"
+
+    def restock(self, n: Union[int, None] = None):
+        """
+        Restock at least n points into this source.
+        If n is not provided, it will default to stock_target, as set during the
+        construction of this point source.
+        """
+        n = n if n is not None else self.stock_target - len(self.stock)
+        if n > 0:
+            logger.info(f"Restocking {type(self).__name__} with {n} points")
+            pts = self.source.get_points(n)
+            self.stock.extend(pts)
+            diff = n - len(pts)
+            if diff > 0:
+                # if implementations of source.get_points are well behaved, this will
+                # never actually need to recurse to finish the job.
+                self.restock(n=diff)
+            logger.info(f"Finished restocking {type(self).__name__}")
+
+    def get_points(self, n: int) -> List[Tuple[float, float]]:
+        """
+        Pull n points from the current stock.
+        It is recommended to call CachedGeoPointSource.restock after this, to ensure 
+        the stock is not depleted. If possible, calling restock in another thread is
+        recommended, as it can be a long operation depending on implementation.
+        """
+        if len(self.stock) >= n:
+            pts = []
+            for _ in range(n):
+                pts.append(self.stock.popleft())
+            return pts
+        self.restock(n=n)
+        # this is safe as long as restock does actually add enough new points.
+        # unless this object is being rapidly drained by another thread,
+        # this will recur at most once.
+        return self.get_points(n=n)
+
+
+class GeoPointSourceGroup:
+    """
+    Container of multiple GeoPointSources, each with some key.
+    """
+    def __init__(self, sources: Dict[str, GeoPointSource], default: GeoPointSource):
+        self.sources = sources
+        self.default = default
+        self.cached = [s for s in sources.values() if isinstance(s, CachedGeoPointSource)]
+        if isinstance(default, CachedGeoPointSource):
+            self.cached.append(default)
+
+    def restock_all(self):
+        """
+        Restock any and all CachedGeoPointSources managed by this group.
+        """
+        for s in self.cached:
+            s.restock()
+
+    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[float, float]]:
+        """
+        Return a list of at least n valid geo points, as (latitude, longitude) pairs,
+        for a given key. If no key is provided, or no matching GeoPointSource is found,
+        the default GeoPointSource will be used.
+        """
+        return self.sources.get(key, self.default).get_points(n)

+ 6 - 6
server/app/schemas.py

@@ -1,13 +1,14 @@
 from enum import Enum
+from typing import Union
 
 from fastapi_camelcase import CamelModel
 from pydantic import conint, confloat
 
 
 class GenMethodEnum(str, Enum):
-    map_crunch = "MAPCRUNCH"
+    # map_crunch = "MAPCRUNCH"
     rsv = "RANDOMSTREETVIEW"
-    urban = "URBAN"
+    # urban = "URBAN"
 
 
 class RuleSetEnum(str, Enum):
@@ -20,8 +21,8 @@ class RuleSetEnum(str, Enum):
 class GameConfig(CamelModel):
     timer: conint(gt=0)
     rounds: conint(gt=0)
-    only_america: bool = False
-    generation_method: GenMethodEnum = GenMethodEnum.map_crunch
+    country_lock: Union[str, None] = None
+    generation_method: GenMethodEnum = GenMethodEnum.rsv
     rule_set: RuleSetEnum = RuleSetEnum.normal
 
     class Config:
@@ -35,6 +36,5 @@ class Guess(CamelModel):
 
 
 class CacheInfo(CamelModel):
-    generation_method: str
-    only_america: bool
+    cache_name: str
     size: int