Эх сурвалжийг харах

Track country codes with geopoints

Kirk Trombley 4 жил өмнө
parent
commit
c3fcd871c7

+ 5 - 1
README.md

@@ -69,6 +69,7 @@ GET /game/{game_id}/coords
         "1": {
         "1": {
             "lat": number,
             "lat": number,
             "lng": number,
             "lng": number,
+            "country": string || null,
         }, ...
         }, ...
     }
     }
 POST /game/{game_id}/join
 POST /game/{game_id}/join
@@ -89,6 +90,7 @@ GET /game/{game_id}/players
                     "1": {
                     "1": {
                         "lat": number,
                         "lat": number,
                         "lng": number,
                         "lng": number,
+                        "country": string || null,
                         "score": number || null,
                         "score": number || null,
                         "timeRemaining": number
                         "timeRemaining": number
                     }, ...
                     }, ...
@@ -102,6 +104,7 @@ GET /game/{game_id}/players/{player_id}/current
         "coord": {
         "coord": {
             "lat": number,
             "lat": number,
             "lng": number,
             "lng": number,
+            "country": string || null,
         } || null,
         } || null,
         "timer": number || null
         "timer": number || null
     }
     }
@@ -122,7 +125,8 @@ POST /game/{game_id}/round/{round}/guess/{player_id}
     Accepts {
     Accepts {
         "timeRemaining": number,
         "timeRemaining": number,
         "lat": number,
         "lat": number,
-        "lng": number
+        "lng": number,
+        "country": string || null (default: null),
     }
     }
     Returns (404, 409) vs 201 and {
     Returns (404, 409) vs 201 and {
         "totalScore": number,
         "totalScore": number,

+ 3 - 0
server/app/api/game.py

@@ -56,6 +56,7 @@ def get_game_coords(game: models.Game = Depends(get_game)):
         str(coord.round_number): {
         str(coord.round_number): {
             "lat": coord.latitude,
             "lat": coord.latitude,
             "lng": coord.longitude,
             "lng": coord.longitude,
+            "country": coord.country_code,
         }
         }
         for coord in game.coordinates
         for coord in game.coordinates
     }
     }
@@ -81,6 +82,7 @@ def get_players(game: models.Game = Depends(get_game)):
                 str(g.round_number): {
                 str(g.round_number): {
                     "lat": g.latitude,
                     "lat": g.latitude,
                     "lng": g.longitude,
                     "lng": g.longitude,
+                    "country": g.country_code,
                     "score": g.round_score,
                     "score": g.round_score,
                     "timeRemaining": g.time_remaining,
                     "timeRemaining": g.time_remaining,
                 } for g in p.guesses
                 } for g in p.guesses
@@ -112,6 +114,7 @@ def get_current_round(db: Session = Depends(get_db), player: models.Player = Dep
         "coord": {
         "coord": {
             "lat": coord.latitude,
             "lat": coord.latitude,
             "lng": coord.longitude,
             "lng": coord.longitude,
+            "country": coord.country_code,
         },
         },
         "timer": queries.get_next_round_time(player),
         "timer": queries.get_next_round_time(player),
     }
     }

+ 2 - 0
server/app/db/models.py

@@ -32,6 +32,7 @@ class Coordinate(Base):
     __tablename__ = "coordinate"
     __tablename__ = "coordinate"
     game_id = Column(String, ForeignKey("game.game_id"), primary_key=True)
     game_id = Column(String, ForeignKey("game.game_id"), primary_key=True)
     round_number = Column(Integer, primary_key=True, autoincrement=False)
     round_number = Column(Integer, primary_key=True, autoincrement=False)
+    country_code = Column(String)
     latitude = Column(Float)
     latitude = Column(Float)
     longitude = Column(Float)
     longitude = Column(Float)
 
 
@@ -42,6 +43,7 @@ class Guess(Base):
     round_number = Column(Integer, primary_key=True, autoincrement=False)
     round_number = Column(Integer, primary_key=True, autoincrement=False)
     latitude = Column(Float)
     latitude = Column(Float)
     longitude = Column(Float)
     longitude = Column(Float)
+    country_code = Column(String)
     round_score = Column(Integer)
     round_score = Column(Integer)
     time_remaining = Column(Float)
     time_remaining = Column(Float)
     created_at = Column(DateTime, default=datetime.utcnow)
     created_at = Column(DateTime, default=datetime.utcnow)

+ 4 - 2
server/app/db/queries.py

@@ -7,7 +7,7 @@ from .models import Game, Coordinate, Player, Guess
 from .. import schemas
 from .. import schemas
 
 
 
 
-def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[float, float]]) -> str:
+def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[str, float, float]]) -> str:
     if len(coords) != conf.rounds:
     if len(coords) != conf.rounds:
         raise ValueError("Insufficient number of coordinates")
         raise ValueError("Insufficient number of coordinates")
     
     
@@ -28,9 +28,10 @@ def create_game(db: Session, conf: schemas.GameConfig, coords: List[Tuple[float,
     db.add_all([Coordinate(
     db.add_all([Coordinate(
         game_id=game_id,
         game_id=game_id,
         round_number=round_num + 1,
         round_number=round_num + 1,
+        country_code=cc,
         latitude=lat,
         latitude=lat,
         longitude=lng,
         longitude=lng,
-    ) for (round_num, (lat, lng)) in enumerate(coords)])
+    ) for (round_num, (cc, lat, lng)) in enumerate(coords)])
     db.commit()
     db.commit()
 
 
     return game_id
     return game_id
@@ -100,6 +101,7 @@ def add_guess(db: Session, guess: schemas.Guess, player: Player, round_number: i
         round_number=round_number,
         round_number=round_number,
         latitude=guess.lat,
         latitude=guess.lat,
         longitude=guess.lng,
         longitude=guess.lng,
+        country_code=guess.country,
         round_score=score,
         round_score=score,
         time_remaining=guess.time_remaining,
         time_remaining=guess.time_remaining,
     )
     )

+ 1 - 1
server/app/point_gen/__init__.py

@@ -12,7 +12,7 @@ source_groups = {
 }
 }
 
 
 
 
-def generate_points(config: GameConfig) -> List[Tuple[float, float]]:
+def generate_points(config: GameConfig) -> List[Tuple[str, float, float]]:
     """
     """
     Generate points according to the GameConfig.
     Generate points according to the GameConfig.
     """
     """

+ 14 - 10
server/app/point_gen/random_street_view.py

@@ -1,8 +1,18 @@
+import random
+
 import requests
 import requests
 
 
 from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, ExhaustedSourceError, GeoPointSourceGroup
 from .shared import point_has_streetview, GeoPointSource, CachedGeoPointSource, ExhaustedSourceError, GeoPointSourceGroup
 
 
 RSV_URL = "https://randomstreetview.com/data"
 RSV_URL = "https://randomstreetview.com/data"
+VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
+                   "br", "bg", "kh", "ca", "cl", "hr", "co", 
+                   "cz", "dk", "ae", "ee", "fi", "fr", "de", 
+                   "gr", "hu", "hk", "is", "id", "ie", "it", 
+                   "il", "jp", "lv", "lt", "my", "mx", "nl", 
+                   "nz", "no", "pe", "pl", "pt", "ro", "ru", 
+                   "sg", "sk", "si", "za", "kr", "es", "sz", 
+                   "se", "ch", "tw", "th", "ua", "gb", "us")
 
 
 
 
 def call_random_street_view(country_lock=None):
 def call_random_street_view(country_lock=None):
@@ -13,8 +23,10 @@ def call_random_street_view(country_lock=None):
 
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
     """
+    if country_lock is None:
+        country_lock = random.choice(VALID_COUNTRIES)
     try:
     try:
-        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower() if country_lock is not None else "all"}).json()
+        rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
     except:
     except:
         return []
         return []
 
 
@@ -22,7 +34,7 @@ def call_random_street_view(country_lock=None):
         return []
         return []
     
     
     return [
     return [
-        (point["lat"], point["lng"])
+        (country_lock, point["lat"], point["lng"])
         for point in rsv_js["locations"]
         for point in rsv_js["locations"]
         if point_has_streetview(point["lat"], point["lng"])
         if point_has_streetview(point["lat"], point["lng"])
     ]
     ]
@@ -48,14 +60,6 @@ class RSVPointSource(GeoPointSource):
 
 
 
 
 WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
 WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
-VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw", 
-                   "br", "bg", "kh", "ca", "cl", "hr", "co", 
-                   "cz", "dk", "ae", "ee", "fi", "fr", "de", 
-                   "gr", "hu", "hk", "is", "id", "ie", "it", 
-                   "il", "jp", "lv", "lt", "my", "mx", "nl", 
-                   "nz", "no", "pe", "pl", "pt", "ro", "ru", 
-                   "sg", "sk", "si", "za", "kr", "es", "sz", 
-                   "se", "ch", "tw", "th", "ua", "gb", "us")
 COUNTRY_SOURCES = {
 COUNTRY_SOURCES = {
     "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
     "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
     **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }
     **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }

+ 8 - 7
server/app/point_gen/shared.py

@@ -40,9 +40,10 @@ class GeoPointSource:
         """
         """
         raise NotImplemented("Must be implemented by subclasses")
         raise NotImplemented("Must be implemented by subclasses")
 
 
-    def get_points(self, n: int) -> List[Tuple[float, float]]:
+    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
         """
         """
-        Return a list of at least n valid geo points, as (latitude, longitude) pairs.
+        Return a list of at least n valid geo points, as 
+        (2 character country code, latitude, longitude) tuples.
         In the event that the GeoPointSource cannot reasonably supply enough points,
         In the event that the GeoPointSource cannot reasonably supply enough points,
         most likely due to time constraints, it should raise an ExhaustedSourceError.
         most likely due to time constraints, it should raise an ExhaustedSourceError.
         """
         """
@@ -86,7 +87,7 @@ class CachedGeoPointSource(GeoPointSource):
                 self.restock(n=diff)
                 self.restock(n=diff)
             logger.info(f"Finished restocking {type(self).__name__}")
             logger.info(f"Finished restocking {type(self).__name__}")
 
 
-    def get_points(self, n: int) -> List[Tuple[float, float]]:
+    def get_points(self, n: int) -> List[Tuple[str, float, float]]:
         """
         """
         Pull n points from the current stock.
         Pull n points from the current stock.
         It is recommended to call CachedGeoPointSource.restock after this, to ensure 
         It is recommended to call CachedGeoPointSource.restock after this, to ensure 
@@ -125,10 +126,10 @@ class GeoPointSourceGroup:
         if isinstance(src, CachedGeoPointSource):
         if isinstance(src, CachedGeoPointSource):
             src.restock()
             src.restock()
 
 
-    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[float, float]]:
+    def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[str, float, float]]:
         """
         """
-        Return a list of at least n valid geo points, as (latitude, longitude) pairs,
-        for a given key. If no key is provided, or no matching GeoPointSource is found,
-        the default GeoPointSource will be used.
+        Return a list of at least n valid geo points, for a given key. If no key 
+        is provided, or no matching GeoPointSource is found, the default 
+        GeoPointSource will be used.
         """
         """
         return self.sources.get(key, self.default).get_points(n)
         return self.sources.get(key, self.default).get_points(n)

+ 4 - 2
server/app/point_gen/urban_centers.py

@@ -19,6 +19,8 @@ with open("./data/urban-centers.csv") as infile:
         _found_countries.add(code)
         _found_countries.add(code)
         _urban_center_count += 1
         _urban_center_count += 1
 logger.info(f"Read {_urban_center_count} urban centers from {len(_found_countries)} countries.")
 logger.info(f"Read {_urban_center_count} urban centers from {len(_found_countries)} countries.")
+VALID_COUNTRIES = tuple(_found_countries)
+
 
 
 def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25):
 def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25):
     """
     """
@@ -33,6 +35,7 @@ def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25)
     This function calls the streetview metadata endpoint - there is no quota consumed.
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
     """
 
 
+    country_lock = country_lock.lower()
     cities = URBAN_CENTERS[country_lock]
     cities = URBAN_CENTERS[country_lock]
     src = random.sample(cities, k=min(city_retries, len(cities)))
     src = random.sample(cities, k=min(city_retries, len(cities)))
 
 
@@ -59,7 +62,7 @@ def urban_coord(country_lock, city_retries=10, point_retries=10, max_dist_km=25)
             pt_lng = math.degrees(pt_lng_rad)
             pt_lng = math.degrees(pt_lng_rad)
             if point_has_streetview(pt_lat, pt_lng):
             if point_has_streetview(pt_lat, pt_lng):
                 logger.info("Point found!")
                 logger.info("Point found!")
-                return (pt_lat, pt_lng)
+                return (country_lock, pt_lat, pt_lng)
 
 
 
 
 class WorldUrbanPointSource(GeoPointSource):
 class WorldUrbanPointSource(GeoPointSource):
@@ -123,7 +126,6 @@ class CountryUrbanSourceDict(dict):
 
 
 
 
 WORLD_SOURCE = CachedGeoPointSource(WorldUrbanPointSource(), 20)
 WORLD_SOURCE = CachedGeoPointSource(WorldUrbanPointSource(), 20)
-VALID_COUNTRIES = tuple(_found_countries)
 COUNTRY_SOURCES = CountryUrbanSourceDict()
 COUNTRY_SOURCES = CountryUrbanSourceDict()
 COUNTRY_SOURCES["us"] = CachedGeoPointSource(CountryUrbanPointSource("us"), 20) # cache US
 COUNTRY_SOURCES["us"] = CachedGeoPointSource(CountryUrbanPointSource("us"), 20) # cache US
 SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)
 SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)

+ 6 - 2
server/app/schemas.py

@@ -5,6 +5,9 @@ from fastapi_camelcase import CamelModel
 from pydantic import conint, confloat, constr
 from pydantic import conint, confloat, constr
 
 
 
 
+CountryCode = constr(to_lower=True, min_length=2, max_length=2)
+
+
 class GenMethodEnum(str, Enum):
 class GenMethodEnum(str, Enum):
     # map_crunch = "MAPCRUNCH"
     # map_crunch = "MAPCRUNCH"
     rsv = "RANDOMSTREETVIEW"
     rsv = "RANDOMSTREETVIEW"
@@ -21,7 +24,7 @@ class RuleSetEnum(str, Enum):
 class GameConfig(CamelModel):
 class GameConfig(CamelModel):
     timer: conint(gt=0)
     timer: conint(gt=0)
     rounds: conint(gt=0)
     rounds: conint(gt=0)
-    country_lock: Union[constr(to_lower=True, min_length=2, max_length=2), None] = None
+    country_lock: Union[CountryCode, None] = None
     generation_method: GenMethodEnum = GenMethodEnum.rsv
     generation_method: GenMethodEnum = GenMethodEnum.rsv
     rule_set: RuleSetEnum = RuleSetEnum.normal
     rule_set: RuleSetEnum = RuleSetEnum.normal
 
 
@@ -33,6 +36,7 @@ class Guess(CamelModel):
     lat: confloat(ge=-90.0, le=90.0)
     lat: confloat(ge=-90.0, le=90.0)
     lng: confloat(ge=-180.0, le=180.0)
     lng: confloat(ge=-180.0, le=180.0)
     time_remaining: int
     time_remaining: int
+    country: Union[CountryCode, None] = None
 
 
 
 
 class CacheInfo(CamelModel):
 class CacheInfo(CamelModel):
@@ -42,4 +46,4 @@ class CacheInfo(CamelModel):
 
 
 class GeneratorInfo(CamelModel):
 class GeneratorInfo(CamelModel):
     generation_method: GenMethodEnum
     generation_method: GenMethodEnum
-    country_locks: List[constr(to_lower=True, min_length=2, max_length=2)]
+    country_locks: List[CountryCode]