Ver Fonte

Fix RSV world mode bug

Kirk Trombley há 4 anos atrás
pai
commit
8e359f6447
1 ficheiros alterados com 27 adições e 15 exclusões
  1. 27 15
      server/app/point_gen/random_street_view.py

+ 27 - 15
server/app/point_gen/random_street_view.py

@@ -15,7 +15,7 @@ VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw",
                    "se", "ch", "tw", "th", "ua", "gb", "us")
 
 
-def call_random_street_view(country_lock=None):
+def call_random_street_view(country_lock):
     """
     Returns an array of (some number of) tuples, each being (latitude, longitude).
     All points will be valid streetview coordinates. There is no guarantee as to the
@@ -23,11 +23,6 @@ def call_random_street_view(country_lock=None):
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
-    if country_lock is None:
-        # TODO this is WRONG - this makes them all from one country!
-        # this whole logic needs to be tweaked - maybe just cache points by country?
-        # a world game should pick a random country N times
-        country_lock = random.choice(VALID_COUNTRIES)
     try:
         rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
     except:
@@ -43,13 +38,13 @@ def call_random_street_view(country_lock=None):
     ]
 
 
-class RSVPointSource(GeoPointSource):
-    def __init__(self, country_lock=None, max_attempts=100):
+class RSVCountryPointSource(GeoPointSource):
+    def __init__(self, country_lock, max_attempts=100):
         self.country_lock = country_lock
         self.max_attempts = max_attempts
 
     def get_name(self):
-        return f"RSV-{self.country_lock or 'all'}"
+        return f"RSV-{self.country_lock}"
 
     def get_points(self, n):
         attempts = 0
@@ -62,9 +57,26 @@ class RSVPointSource(GeoPointSource):
         return points
 
 
-WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
-COUNTRY_SOURCES = {
-    "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
-    **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }
-}
-SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)
+COUNTRY_SOURCES = { k: CachedGeoPointSource(RSVCountryPointSource(k), 10) for k in VALID_COUNTRIES }
+
+
+class RSVWorldPointSource(GeoPointSource):
+    def __init__(self, max_attempts=10):
+        self.max_attempts = max_attempts
+
+    def get_name(self):
+        return "RSV-global"
+
+    def get_points(self, n):
+        attempts = 0
+        points = []
+        while len(points) < n:
+            if attempts > self.max_attempts:
+                raise ExhaustedSourceError(points)
+            points.extend(COUNTRY_SOURCES[random.choice(VALID_COUNTRIES)].get_points(1))
+            attempts += 1
+        return points
+
+
+WORLD_SOURCE = RSVWorldPointSource()
+SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)