Browse Source

Refactor point gen logic for clarity

Kirk Trombley 4 years ago
parent
commit
5947290fa4
1 changed files with 47 additions and 43 deletions
  1. 47 43
      server/app/point_gen/__init__.py

+ 47 - 43
server/app/point_gen/__init__.py

@@ -39,14 +39,31 @@ class PointStore:
                  rsv_country_retries: int = 5,
                  urban_country_pool_size: int = 30,
                  urban_country_retries: int = 30,
-                 urban_city_retries: int = 50):
+                 urban_city_retries: int = 50,
+                 urban_city_retries_per_random_country: int = 10):
         self.cache_targets = cache_targets
         self.rsv_country_retries = rsv_country_retries
         self.urban_country_pool_size = urban_country_pool_size
         self.urban_country_retries = urban_country_retries
         self.urban_city_retries = urban_city_retries
+        self.urban_city_retries_per_random_country = urban_city_retries_per_random_country
         self.store = collections.defaultdict(collections.deque)
 
+    async def _gen_rsv_point(self, country: CountryCode):
+        # RSV point function returns a collection of points, which should be cached
+        points = await call_random_street_view(country)
+        if len(points) > 0:
+            point = points.pop()
+            self.store[(GenMethodEnum.rsv, country)].extend(points)
+            return point
+
+    async def _gen_urban_point(self, countries: List[CountryCode], city_retries: int):
+        for country in countries:
+            logger.info(f"Selecting urban centers from {country}")
+            pt = await urban_coord(country, city_retries=city_retries)
+            if pt is not None:
+                return pt
+
     async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None], force_generate: bool = False) -> Tuple[str, float, float]:
         if country is None:
             # generating points across the whole world
@@ -55,53 +72,39 @@ class PointStore:
                 for _ in range(self.rsv_country_retries):
                     # try a few countries before giving up, just in case one has no data
                     country = random.choice(RSV_COUNTRIES)
-                    # RSV point function returns a collection of points, which should be cached
-                    points = await call_random_street_view(country)
-                    if len(points) > 0:
-                        point = points.pop()
-                        self.store[(generator, country)].extend(points)
+                    point = await self._gen_rsv_point(country)
+                    if point is not None:
                         return point
-                else:
-                    raise ExhaustedSourceError
             elif generator == GenMethodEnum.urban:
                 # try many countries since finding an urban center point is harder
                 countries = random.sample(URBAN_COUNTRIES, k=min(self.urban_country_pool_size, len(URBAN_COUNTRIES)))
-                for country in countries:
-                    logger.info(f"Selecting urban centers from {country}")
-                    pt = await urban_coord(country)
-                    if pt is not None:
-                        return pt
-                else:
-                    raise ExhaustedSourceError
-            else:
-                raise ExhaustedSourceError
-        else:
-            # generating points for a specific country
-            # if we already have a point ready, just return it immediately
-            if not force_generate:
-                stock = self.store[(generator, country)]
-                if len(stock) > 0:
-                    return stock.popleft()
-            
-            # otherwise, need to actually generate a new point
-            if generator == GenMethodEnum.rsv:
-                # RSV point function returns a collection of points, which should be cached
-                points = await call_random_street_view(country)
-                if len(points) == 0:
-                    raise ExhaustedSourceError
-                point = points.pop()
-                self.store[(generator, country)].extend(points)
+                point = await self._gen_urban_point(countries, self.urban_city_retries_per_random_country)
+                if point is not None:
+                    return point
+
+            # if nothing could be done - inform the caller
+            raise ExhaustedSourceError
+
+        # generating points for a specific country
+
+        # if we already have a point ready, just return it immediately
+        if not force_generate:
+            stock = self.store[(generator, country)]
+            if len(stock) > 0:
+                return stock.popleft()
+        
+        # otherwise, need to actually generate a new point
+        if generator == GenMethodEnum.rsv:
+            point = await self._gen_rsv_point(country)
+            if point is not None:
                 return point
-            elif generator == GenMethodEnum.urban:
-                for i in range(self.urban_country_retries):
-                    logger.info(f"Attempt #{i + 1} to select urban centers from {country}")
-                    pt = await urban_coord(country, city_retries=self.urban_city_retries)
-                    if pt is not None:
-                        return pt
-                else:
-                    raise ExhaustedSourceError
-            else:
-                raise ExhaustedSourceError
+        elif generator == GenMethodEnum.urban:
+            point = await self._gen_urban_point((country for _ in range(self.urban_country_retries)), self.urban_city_retries)
+            if point is not None:
+                return point
+
+        # finally, if all that fails, just inform the caller
+        raise ExhaustedSourceError
 
     async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
         """
@@ -117,6 +120,7 @@ class PointStore:
             point_tasks = [self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)]
             gathered = asyncio.gather(*point_tasks)
             return await asyncio.wait_for(gathered, 60)
+            # TODO - it would be nice to keep partially generated sets around if there's a timeout or exhaustion
         except asyncio.TimeoutError:
             raise ExhaustedSourceError