__init__.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import asyncio
  2. import collections
  3. import logging
  4. import random
  5. from itertools import groupby
  6. from typing import List, Tuple, Dict, Union
  7. from .random_street_view import call_random_street_view, VALID_COUNTRIES as RSV_COUNTRIES
  8. from .urban_centers import urban_coord, VALID_COUNTRIES as URBAN_COUNTRIES
  9. from .shared import aiohttp_client, reverse_geocode
  10. from ..schemas import GameConfig, GenMethodEnum, CountryCode, CacheInfo, GeneratorInfo
  11. logger = logging.getLogger(__name__)
  12. generator_info = [
  13. GeneratorInfo(
  14. generation_method=GenMethodEnum.rsv,
  15. country_locks=RSV_COUNTRIES
  16. ),
  17. GeneratorInfo(
  18. generation_method=GenMethodEnum.urban,
  19. country_locks=URBAN_COUNTRIES
  20. ),
  21. ]
  22. cache_names = {
  23. GenMethodEnum.rsv: "RSV",
  24. GenMethodEnum.urban: "Urban",
  25. }
  26. class ExhaustedSourceError(Exception):
  27. pass
  28. DIFFICULTY_1 = ["ee"] # TODO actually fill these out
  29. DIFFICULTY_2 = ["ee"]
  30. DIFFICULTY_3 = ["ee"]
  31. DIFFICULTY_4 = ["ee"]
  32. DIFFICULTY_5 = ["ee"]
  33. DIFFICULTY_X = ["ee"] # TODO all islands
  34. DIFFICULTY_TIER_ORDER = (
  35. DIFFICULTY_1, DIFFICULTY_2, DIFFICULTY_3, DIFFICULTY_4,
  36. DIFFICULTY_5,
  37. DIFFICULTY_4, DIFFICULTY_3, DIFFICULTY_2, DIFFICULTY_1,
  38. DIFFICULTY_X,
  39. )
  40. class PointStore:
  41. def __init__(self,
  42. cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int],
  43. rsv_country_retries: int = 5,
  44. urban_country_pool_size: int = 30,
  45. urban_country_retries: int = 30,
  46. urban_city_retries: int = 50,
  47. urban_city_retries_per_random_country: int = 10):
  48. self.cache_targets = cache_targets
  49. self.rsv_country_retries = rsv_country_retries
  50. self.urban_country_pool_size = urban_country_pool_size
  51. self.urban_country_retries = urban_country_retries
  52. self.urban_city_retries = urban_city_retries
  53. self.urban_city_retries_per_random_country = urban_city_retries_per_random_country
  54. self.store = collections.defaultdict(collections.deque)
  55. async def _gen_rsv_point(self, country: CountryCode):
  56. # RSV point function returns a collection of points, which should be cached
  57. for actual_country, points in groupby(await call_random_street_view(country), key=lambda p: p[0]):
  58. # but these points need to be cached according to the actual reverse geocoded country they are in
  59. self.store[(GenMethodEnum.rsv, actual_country)].extend(points)
  60. stock = self.store[(GenMethodEnum.rsv, country)]
  61. if len(stock) > 0:
  62. return stock.popleft()
  63. async def _gen_urban_point(self, countries: List[CountryCode], city_retries: int):
  64. for country in countries:
  65. logger.info(f"Selecting urban centers from {country}")
  66. pt = await urban_coord(country, city_retries=city_retries)
  67. if pt is not None:
  68. if pt[0] == country:
  69. return pt
  70. else:
  71. # TODO technically this is slightly wasted effort in rare edge cases
  72. self.store[(GenMethodEnum.urban, pt[0])].append(pt)
  73. # TODO I think all of this logic still gets stuck in the trap of generating points even when there's a stock in some edge cases
  74. # this needs a rewrite but I'm not doing that now
  75. async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None], force_generate: bool = False) -> Tuple[str, float, float]:
  76. if country is None:
  77. # generating points across the whole world
  78. # for current generators, this means selecting a country at random
  79. if generator == GenMethodEnum.rsv:
  80. for _ in range(self.rsv_country_retries):
  81. # try a few countries before giving up, just in case one has no data
  82. country = random.choice(RSV_COUNTRIES)
  83. point = await self._gen_rsv_point(country)
  84. if point is not None:
  85. return point
  86. elif generator == GenMethodEnum.urban:
  87. # try many countries since finding an urban center point is harder
  88. countries = random.sample(URBAN_COUNTRIES, k=min(self.urban_country_pool_size, len(URBAN_COUNTRIES)))
  89. point = await self._gen_urban_point(countries, self.urban_city_retries_per_random_country)
  90. if point is not None:
  91. return point
  92. # if nothing could be done - inform the caller
  93. raise ExhaustedSourceError
  94. # generating points for a specific country
  95. # if we already have a point ready, just return it immediately
  96. if not force_generate:
  97. stock = self.store[(generator, country)]
  98. if len(stock) > 0:
  99. return stock.popleft()
  100. # otherwise, need to actually generate a new point
  101. if generator == GenMethodEnum.rsv:
  102. point = await self._gen_rsv_point(country)
  103. if point is not None:
  104. return point
  105. elif generator == GenMethodEnum.urban:
  106. point = await self._gen_urban_point((country for _ in range(self.urban_country_retries)), self.urban_city_retries)
  107. if point is not None:
  108. return point
  109. # finally, if all that fails, just inform the caller
  110. raise ExhaustedSourceError
  111. async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
  112. """
  113. Provide points according to the GameConfig.
  114. Return a list of valid geo points, as
  115. (2 character country code, latitude, longitude) tuples.
  116. In the event that the configured source cannot reasonably supply enough points,
  117. most likely due to time constraints, this will raise an ExhaustedSourceError.
  118. """
  119. try:
  120. if config.generation_method == GenMethodEnum.diff_tiered:
  121. # in the case of using the "difficulty tiered" generator there is some special logic
  122. # assume that, in general, we want 10 points (4 normal rounds going up in difficulty, 1 max difficulty round, 4 normal going down, 1 nightmare tier)
  123. # if more are requested, it repeats. if less, it only goes that far.
  124. def make_point_task(tier):
  125. country_lock = random.choice(tier)
  126. if country_lock in random_street_view.VALID_COUNTRIES:
  127. return self.get_point(GenMethodEnum.rsv, country_lock)
  128. elif country_lock in urban_centers.VALID_COUNTRIES:
  129. return self.get_point(GenMethodEnum.urban, country_lock)
  130. else:
  131. raise ExhaustedSourceError
  132. point_tasks = [make_point_task(DIFFICULTY_TIER_ORDER[i % len(DIFFICULTY_TIER_ORDER)]) for i in range(config.rounds)]
  133. else:
  134. point_tasks = [self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)]
  135. gathered = asyncio.gather(*point_tasks)
  136. return await asyncio.wait_for(gathered, 60)
  137. # TODO - it would be nice to keep partially generated sets around if there's a timeout or exhaustion
  138. except asyncio.TimeoutError:
  139. raise ExhaustedSourceError
  140. def get_cache_info(self) -> List[CacheInfo]:
  141. """
  142. Get CacheInfo for all caches.
  143. """
  144. return [CacheInfo(cache_name=f"{cache_names[g]}-{c}", size=len(ps)) for (g, c), ps in self.store.items()]
  145. async def _restock_source_impl(self, generator: GenMethodEnum, country: CountryCode):
  146. key = (generator, country)
  147. target = self.cache_targets.get(key, 0)
  148. stock = self.store[key]
  149. while len(stock) < target: # this check allows for RSV to do its multi-point restock
  150. stock.append(await self.get_point(*key, force_generate=True))
  151. async def restock_source(self, config: GameConfig):
  152. """
  153. Restock any caches associated with the GameConfig.
  154. """
  155. if config.country_lock is None:
  156. return
  157. try:
  158. await self._restock_source_impl(config.generation_method, config.country_lock)
  159. except ExhaustedSourceError:
  160. # if the cache can't be restocked, that is bad, but not fatal
  161. logger.exception(f"Failed to fully restock point cache for {config}")
  162. async def restock_all(self, timeout: Union[int, float, None] = None):
  163. """
  164. Restock all caches.
  165. """
  166. restock_tasks = [self._restock_source_impl(gen, cc) for (gen, cc) in self.cache_targets.keys()]
  167. gathered = asyncio.gather(*restock_tasks)
  168. try:
  169. await asyncio.wait_for(gathered, timeout)
  170. except (asyncio.TimeoutError, ExhaustedSourceError):
  171. # if this task times out, it's fine, as it's just intended to be a best effort
  172. logger.exception("Failed to fully restock a point cache!")
  173. points = PointStore({
  174. (GenMethodEnum.urban, "us"): 10,
  175. })