Kirk Trombley 4 жил өмнө
parent
commit
8e359f6447

+ 27 - 15
server/app/point_gen/random_street_view.py

@@ -15,7 +15,7 @@ VALID_COUNTRIES = ("ad", "au", "ar", "bd", "be", "bt", "bw",
                    "se", "ch", "tw", "th", "ua", "gb", "us")
                    "se", "ch", "tw", "th", "ua", "gb", "us")
 
 
 
 
-def call_random_street_view(country_lock=None):
+def call_random_street_view(country_lock):
     """
     """
     Returns an array of (some number of) tuples, each being (latitude, longitude).
     Returns an array of (some number of) tuples, each being (latitude, longitude).
     All points will be valid streetview coordinates. There is no guarantee as to the
     All points will be valid streetview coordinates. There is no guarantee as to the
@@ -23,11 +23,6 @@ def call_random_street_view(country_lock=None):
 
 
     This function calls the streetview metadata endpoint - there is no quota consumed.
     This function calls the streetview metadata endpoint - there is no quota consumed.
     """
     """
-    if country_lock is None:
-        # TODO this is WRONG - this makes them all from one country!
-        # this whole logic needs to be tweaked - maybe just cache points by country?
-        # a world game should pick a random country N times
-        country_lock = random.choice(VALID_COUNTRIES)
     try:
     try:
         rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
         rsv_js = requests.post(RSV_URL, data={"country": country_lock.lower()}).json()
     except:
     except:
@@ -43,13 +38,13 @@ def call_random_street_view(country_lock=None):
     ]
     ]
 
 
 
 
-class RSVPointSource(GeoPointSource):
-    def __init__(self, country_lock=None, max_attempts=100):
+class RSVCountryPointSource(GeoPointSource):
+    def __init__(self, country_lock, max_attempts=100):
         self.country_lock = country_lock
         self.country_lock = country_lock
         self.max_attempts = max_attempts
         self.max_attempts = max_attempts
 
 
     def get_name(self):
     def get_name(self):
-        return f"RSV-{self.country_lock or 'all'}"
+        return f"RSV-{self.country_lock}"
 
 
     def get_points(self, n):
     def get_points(self, n):
         attempts = 0
         attempts = 0
@@ -62,9 +57,26 @@ class RSVPointSource(GeoPointSource):
         return points
         return points
 
 
 
 
-WORLD_SOURCE = CachedGeoPointSource(RSVPointSource(), 10)
-COUNTRY_SOURCES = {
-    "us": CachedGeoPointSource(RSVPointSource("us"), 10),   # cache US specifically since it is commonly used
-    **{ k: RSVPointSource(k) for k in VALID_COUNTRIES if k not in ("us",) }
-}
-SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)
+COUNTRY_SOURCES = { k: CachedGeoPointSource(RSVCountryPointSource(k), 10) for k in VALID_COUNTRIES }
+
+
+class RSVWorldPointSource(GeoPointSource):
+    def __init__(self, max_attempts=10):
+        self.max_attempts = max_attempts
+
+    def get_name(self):
+        return "RSV-global"
+
+    def get_points(self, n):
+        attempts = 0
+        points = []
+        while len(points) < n:
+            if attempts > self.max_attempts:
+                raise ExhaustedSourceError(points)
+            points.extend(COUNTRY_SOURCES[random.choice(VALID_COUNTRIES)].get_points(1))
+            attempts += 1
+        return points
+
+
+WORLD_SOURCE = RSVWorldPointSource()
+SOURCE_GROUP = GeoPointSourceGroup(COUNTRY_SOURCES, WORLD_SOURCE)