Browse Source

Track country codes with geopoints

Kirk Trombley 4 years ago
parent
commit
c3fcd871c7

+ 5 - 1
README.md

@@ -69,6 +69,7 @@ GET /game/{game_id}/coords
         "1": {
             "lat": number,
             "lng": number,
+            "country": string || null,
         }, ...
     }
 POST /game/{game_id}/join
@@ -89,6 +90,7 @@ GET /game/{game_id}/players
                     "1": {
                         "lat": number,
                         "lng": number,
+                        "country": string || null,
                         "score": number || null,
                         "timeRemaining": number
                     }, ...
@@ -102,6 +104,7 @@ GET /game/{game_id}/players/{player_id}/current
         "coord": {
             "lat": number,
             "lng": number,
+            "country": string || null,
         } || null,
         "timer": number || null
     }
@@ -122,7 +125,8 @@ POST /game/{game_id}/round/{round}/guess/{player_id}
     Accepts {
         "timeRemaining": number,
         "lat": number,
-        "lng": number
+        "lng": number,
+        "country": string || null (default: null),
     }
     Returns (404, 409) vs 201 and {
         "totalScore": number,

+ 3 - 0
server/app/api/game.py

@@ -56,6 +56,7 @@ def get_game_coords(game: models.Game = Depends(get_game)):
         str(coord.round_number): {
             "lat": coord.latitude,
             "lng": coord.longitude,
+            "country": coord.country_code,
         }
         for coord in game.coordinates
     }
@@ -81,6 +82,7 @@ def get_players(game: models.Game = Depends(get_game)):
                 str(g.round_number): {
                     "lat": g.latitude,
                     "lng": g.longitude,
+                    "country": g.country_code,
                     "score": g.round_score,
                     "timeRemaining": g.time_remaining,
                 } for g in p.guesses
@@ -112,6 +114,7 @@ def get_current_round(db: Session = Depends(get_db), player: models.Player = Dep
         "coord": {
             "lat": coord.latitude,
             "lng": coord.longitude,
+            "country": coord.country_code,
         },
         "timer": queries.get_next_round_time(player),
     }

+ 2 - 0
server/app/db/models.py

@@ -32,6 +32,7 @@ class Coordinate(Base):
     __tablename__ = "coordinate"
     game_id = Column(String, ForeignKey("game.game_id"), primary_key=True)
     round_number = Column(Integer, primary_key=True, autoincrement=False)
+    country_code = Column(String)
     latitude = Column(Float)
     longitude = Column(Float)
 
@@ -42,6 +43,7 @@ class Guess(Base):
     round_number = Column(Integer, primary_key=True, autoincrement=False)
     latitude = Column(Float)
     longitude = Column(Float)
+    country_code = Column(String)
     round_score = Column(Integer)
     time_remaining = Column(Float)
     created_at = Column(DateTime, default=datetime.utcnow)

+ 4 - 2
server/app/db/queries.py

@@ -7,7 +7,7 @@ from .models import Game, Coordinate, Player, Guess
 from .. import schemas
 
 
-def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[float, float]]) -> str:
+def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[str, float, float]]) -> str:
     if len(coords) != conf.rounds:
         raise ValueError("Insufficient number of coordinates")
     
@@ -28,9 +28,10 @@ def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[float,
     db.add_all([Coordinate(
         game_id=game_id,
         round_number=round_num + 1,
+        country_code=cc,
         latitude=lat,
         longitude=lng,
-    ) for (round_num, (lat, lng)) in enumerate(coords)])
+    ) for (round_num, (cc, lat, lng)) in enumerate(coords)])
     db.commit()
 
     return game_id
@@ -100,6 +101,7 @@ def add_guess(db: Session, guess: schemas.Guess, player: Player, round_number: i
         round_number=round_number,
         latitude=guess.lat,
         longitude=guess.lng,
+        country_code=guess.country,
         round_score=score,
         time_remaining=guess.time_remaining,
     )

+ 1 - 1
server/app/point_gen/__init__.py

@@ -12,7 +12,7 @@ source_groups = {
 }
 
 
-def generate_points(config: GameConfig) -> List[Tuple[float, float]]:
+def generate_points(config: GameConfig) -> List[Tuple[str, float, float]]:
     """
     Generate points according to the GameConfig.
     """

+ 14 - 10
server/app/point_gen/random_street_view.py

@@ -1,8 +1,18 @@
+import random
+
 import requests
 
 from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, ExhaustedSourceError, GeoPointSourceGroup
 
 RSV_URL = "https://randomstreetview.com/data"
+VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
+                   "br", "bg", "kh", "ca", "cl", "hr", "co", 
+                   "cz", "dk", "ae", "ee", "fi", "fr", "de", 
+                   "gr", "hu", "hk", "is", "id", "ie", "it", 
+                   "il", "jp", "lv", "lt", "my", "mx", "nl", 
+                   "nz", "no", "pe", "pl", "pt", "ro", "ru", 
+                   "sg", "sk", "si", "za", "kr", "es", "sz", 
+                   "se", "ch", "tw", "th", "ua", "gb", "us")
 
 
 def call_random_street_view(country_lock=None):
@@ -13,8 +23,10 @@ def call_random_street_view(country_lock=None):
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
+    if country_lock is None:
+        country_lock = random.choice(VALID_COUNTRIES)
     try:
-        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower() if country_lock is not None else "all"}).json()
+        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
     except:
         return []
 
@@ -22,7 +34,7 @@ def call_random_street_view(country_lock=None):
         return []
     
     return [
-        (point["lat"], point["lng"])
+        (country_lock, point["lat"], point["lng"])
         for point in rsv_js["locations"]
         if point_has_streetview(point["lat"], point["lng"])
     ]
@@ -48,14 +60,6 @@ class RSVPointSource(GeoPointSource):
 
 
 WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
-VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
-                   "br", "bg", "kh", "ca", "cl", "hr", "co", 
-                   "cz", "dk", "ae", "ee", "fi", "fr", "de", 
-                   "gr", "hu", "hk", "is", "id", "ie", "it", 
-                   "il", "jp", "lv", "lt", "my", "mx", "nl", 
-                   "nz", "no", "pe", "pl", "pt", "ro", "ru", 
-                   "sg", "sk", "si", "za", "kr", "es", "sz", 
-                   "se", "ch", "tw", "th", "ua", "gb", "us")
 COUNTRY_SOURCES = {
     "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
     **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }

+ 8 - 7
server/app/point_gen/shared.py

@@ -40,9 +40,10 @@ class GeoPointSource:
         """
         raise NotImplemented("Must be implemented by subclasses")
 
-    def get_points(self, n: int) -> List[Tuple[float, float]]:
+    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
         """
-        Return a list of at least n valid geo points, as (latitude, longitude) pairs.
+        Return a list of at least n valid geo points, as 
+        (2 character country code, latitude, longitude) tuples.
         In the event that the GeoPointSource cannot reasonably supply enough points,
         most likely due to time constraints, it should raise an ExhaustedSourceError.
         """
@@ -86,7 +87,7 @@ class CachedGeoPointSource(GeoPointSource):
                 self.restock(n=diff)
             logger.info(f"Finished restocking {type(self).__name__}")
 
-    def get_points(self, n: int) -> List[Tuple[float, float]]:
+    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
         """
         Pull n points from the current stock.
         It is recommended to call CachedGeoPointSource.restock after this, to ensure 
@@ -125,10 +126,10 @@ class GeoPointSourceGroup:
         if isinstance(src, CachedGeoPointSource):
             src.restock()
 
-    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[float, float]]:
+    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[str, float, float]]:
         """
-        Return a list of at least n valid geo points, as (latitude, longitude) pairs,
-        for a given key. If no key is provided, or no matching GeoPointSource is found,
-        the default GeoPointSource will be used.
+        Return a list of at least n valid geo points, for a given key. If no key 
+        is provided, or no matching GeoPointSource is found, the default 
+        GeoPointSource will be used.
         """
         return self.sources.get(key, self.default).get_points(n)

+ 4 - 2
server/app/point_gen/urban_centers.py

@@ -19,6 +19,8 @@ with open("./data/urban-centers.csv") as infile:
         _found_countries.add(code)
         _urban_center_count += 1
 logger.info(f"Read {_urban_center_count} urban centers from {len(_found_countries)} countries.")
+VALID_COUNTRIES = tuple(_found_countries)
+
 
 def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25):
     """
@@ -33,6 +35,7 @@ def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25)
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
 
+    country_lock = country_lock.lower()
     cities = URBAN_CENTERS[country_lock]
     src = random.sample(cities, k=min(city_retries, len(cities)))
 
@@ -59,7 +62,7 @@ def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25)
             pt_lng = math.degrees(pt_lng_rad)
             if point_has_streetview(pt_lat, pt_lng):
                 logger.info("Point found!")
-                return (pt_lat, pt_lng)
+                return (country_lock, pt_lat, pt_lng)
 
 
 class WorldUrbanPointSource(GeoPointSource):
@@ -123,7 +126,6 @@ class CountryUrbanSourceDict(dict):
 
 
 WORLD_SOURCE = CachedGeoPointSource(WorldUrbanPointSource(), 20)
-VALID_COUNTRIES = tuple(_found_countries)
 COUNTRY_SOURCES = CountryUrbanSourceDict()
 COUNTRY_SOURCES["us"] = CachedGeoPointSource(CountryUrbanPointSource("us"), 20) # cache US
 SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)

+ 6 - 2
server/app/schemas.py

@@ -5,6 +5,9 @@ from fastapi_camelcase import CamelModel
 from pydantic import conint, confloat, constr
 
 
+CountryCode = constr(to_lower=True, min_length=2, max_length=2)
+
+
 class GenMethodEnum(str, Enum):
     # map_crunch = "MAPCRUNCH"
     rsv = "RANDOMSTREETVIEW"
@@ -21,7 +24,7 @@ class RuleSetEnum(str, Enum):
 class GameConfig(CamelModel):
     timer: conint(gt=0)
     rounds: conint(gt=0)
-    country_lock: Union[constr(to_lower=True, min_length=2, max_length=2), None] = None
+    country_lock: Union[CountryCode, None] = None
     generation_method: GenMethodEnum = GenMethodEnum.rsv
     rule_set: RuleSetEnum = RuleSetEnum.normal
 
@@ -33,6 +36,7 @@ class Guess(CamelModel):
     lat: confloat(ge=-90.0, le=90.0)
     lng: confloat(ge=-180.0, le=180.0)
     time_remaining: int
+    country: Union[CountryCode, None] = None
 
 
 class CacheInfo(CamelModel):
@@ -42,4 +46,4 @@ class CacheInfo(CamelModel):
 
 class GeneratorInfo(CamelModel):
     generation_method: GenMethodEnum
-    country_locks: List[constr(to_lower=True, min_length=2, max_length=2)]
+    country_locks: List[CountryCode]