__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import asyncio
  2. import collections
  3. import logging
  4. import random
  5. from typing import List, Tuple, Dict, Union
  6. from .random_street_view import call_random_street_view, VALID_COUNTRIES as RSV_COUNTRIES
  7. from .urban_centers import urban_coord, VALID_COUNTRIES as URBAN_COUNTRIES
  8. from .shared import aiohttp_client
  9. from ..schemas import GameConfig, GenMethodEnum, CountryCode, CacheInfo, GeneratorInfo
  logger = logging.getLogger(__name__)

  # One GeneratorInfo per supported generation method, each paired with the
  # countries that method can be locked to. Order: RSV first, then urban.
  generator_info = [
      GeneratorInfo(generation_method=method, country_locks=countries)
      for method, countries in (
          (GenMethodEnum.rsv, RSV_COUNTRIES),
          (GenMethodEnum.urban, URBAN_COUNTRIES),
      )
  ]
  # Display prefix for each generation method; PointStore.get_cache_info joins
  # it with the country code to form the cache name.
  cache_names = dict((
      (GenMethodEnum.rsv, "RSV"),
      (GenMethodEnum.urban, "Urban"),
  ))
  class ExhaustedSourceError(Exception):
      """Raised when a point source cannot supply a point.

      PointStore raises this when every retry for a generator comes back
      empty, when the generator is not one it knows how to drive, or when
      gathering the points for a game runs out of time.
      """
  class PointStore:
      """
      Produce geo points on demand and keep per-source caches of spare points.

      A "source" is a (generation method, country code) pair. Spare points for
      each source live in a deque inside ``self.store``; ``cache_targets`` maps
      a source to the number of points the restock methods try to keep ready.
      """

      # Seconds get_points waits for every round's point before giving up.
      _POINTS_TIMEOUT_SECONDS = 60

      def __init__(self,
                   cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int],
                   rsv_country_retries: int = 5,
                   urban_country_pool_size: int = 30,
                   urban_country_retries: int = 30,
                   urban_city_retries: int = 50):
          """
          :param cache_targets: desired cache size per (generator, country) source
          :param rsv_country_retries: random countries to try for a worldwide RSV point
          :param urban_country_pool_size: random countries to try for a worldwide urban point
          :param urban_country_retries: attempts for an urban point in one specific country
          :param urban_city_retries: forwarded to urban_coord as ``city_retries``
          """
          self.cache_targets = cache_targets
          self.rsv_country_retries = rsv_country_retries
          self.urban_country_pool_size = urban_country_pool_size
          self.urban_country_retries = urban_country_retries
          self.urban_city_retries = urban_city_retries
          # (generator, country) -> deque of ready-to-serve points
          self.store = collections.defaultdict(collections.deque)

      async def _fetch_rsv_point(self, generator: GenMethodEnum, country: CountryCode):
          """Fetch an RSV batch for one country; cache the extras, return one point or None."""
          batch = await call_random_street_view(country)
          if not batch:
              return None
          point = batch.pop()
          # the RSV call yields several points at once, so keep the rest for later
          self.store[(generator, country)].extend(batch)
          return point

      async def _generate_worldwide(self, generator: GenMethodEnum) -> Tuple[str, float, float]:
          """Generate a point with no country lock by trying randomly chosen countries."""
          if generator == GenMethodEnum.rsv:
              # try a few countries before giving up, just in case one has no data
              for _ in range(self.rsv_country_retries):
                  point = await self._fetch_rsv_point(generator, random.choice(RSV_COUNTRIES))
                  if point is not None:
                      return point
              raise ExhaustedSourceError(
                  f"No RSV point found after trying {self.rsv_country_retries} countries")
          if generator == GenMethodEnum.urban:
              # try many countries since finding an urban center point is harder
              pool_size = min(self.urban_country_pool_size, len(URBAN_COUNTRIES))
              for candidate in random.sample(URBAN_COUNTRIES, k=pool_size):
                  logger.info(f"Selecting urban centers from {candidate}")
                  point = await urban_coord(candidate)
                  if point is not None:
                      return point
              raise ExhaustedSourceError(
                  f"No urban point found after trying {pool_size} countries")
          raise ExhaustedSourceError(f"Unsupported generator: {generator}")

      async def _generate_for_country(self, generator: GenMethodEnum, country: CountryCode) -> Tuple[str, float, float]:
          """Generate a fresh point for one specific country, bypassing the cache."""
          if generator == GenMethodEnum.rsv:
              point = await self._fetch_rsv_point(generator, country)
              if point is None:
                  raise ExhaustedSourceError(f"No RSV points available for {country}")
              return point
          if generator == GenMethodEnum.urban:
              for attempt in range(1, self.urban_country_retries + 1):
                  logger.info(f"Attempt #{attempt} to select urban centers from {country}")
                  point = await urban_coord(country, city_retries=self.urban_city_retries)
                  if point is not None:
                      return point
              raise ExhaustedSourceError(
                  f"No urban point for {country} after {self.urban_country_retries} attempts")
          raise ExhaustedSourceError(f"Unsupported generator: {generator}")

      async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None], force_generate: bool = False) -> Tuple[str, float, float]:
          """
          Return a single point for the given generator.

          With ``country`` set to None the point comes from a randomly chosen
          country and the cache is never consulted. Otherwise a cached point is
          returned when one is available, unless ``force_generate`` is set, in
          which case a new point is always generated.

          Raises ExhaustedSourceError if no point could be produced.
          """
          if country is None:
              # generating points across the whole world; for current generators,
              # this means selecting a country at random
              return await self._generate_worldwide(generator)
          if not force_generate:
              # if we already have a point ready, just return it immediately
              stock = self.store[(generator, country)]
              if stock:
                  return stock.popleft()
          return await self._generate_for_country(generator, country)

      async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
          """
          Provide points according to the GameConfig.
          Return a list of valid geo points, as
          (2 character country code, latitude, longitude) tuples.
          In the event that the configured source cannot reasonably supply enough points,
          most likely due to time constraints, this will raise an ExhaustedSourceError.
          """
          point_tasks = [self.get_point(config.generation_method, config.country_lock)
                         for _ in range(config.rounds)]
          gathered = asyncio.gather(*point_tasks)
          try:
              return await asyncio.wait_for(gathered, self._POINTS_TIMEOUT_SECONDS)
          except asyncio.TimeoutError as err:
              raise ExhaustedSourceError(
                  f"Timed out after {self._POINTS_TIMEOUT_SECONDS}s while gathering points") from err

      def get_cache_info(self) -> List[CacheInfo]:
          """
          Get CacheInfo for all caches.
          """
          return [CacheInfo(cache_name=f"{cache_names[g]}-{c}", size=len(ps))
                  for (g, c), ps in self.store.items()]

      async def _restock_source_impl(self, generator: GenMethodEnum, country: CountryCode):
          """Generate points for one source until its cache reaches the configured target."""
          key = (generator, country)
          target = self.cache_targets.get(key, 0)
          stock = self.store[key]
          # re-checking the length each pass allows for RSV to do its multi-point restock
          while len(stock) < target:
              stock.append(await self.get_point(*key, force_generate=True))

      async def restock_source(self, config: GameConfig):
          """
          Restock any caches associated with the GameConfig.
          """
          if config.country_lock is None:
              # worldwide games have no single cache to refill
              return
          try:
              await self._restock_source_impl(config.generation_method, config.country_lock)
          except ExhaustedSourceError:
              # if the cache can't be restocked, that is bad, but not fatal
              logger.exception(f"Failed to fully restock point cache for {config}")

      async def restock_all(self, timeout: Union[int, float, None] = None):
          """
          Restock all caches.

          Best effort: a timeout or an exhausted source is logged, not raised.
          ``timeout`` is in seconds; None waits indefinitely.
          """
          restock_tasks = [self._restock_source_impl(gen, cc) for (gen, cc) in self.cache_targets]
          gathered = asyncio.gather(*restock_tasks)
          try:
              await asyncio.wait_for(gathered, timeout)
          except (asyncio.TimeoutError, ExhaustedSourceError):
              # if this task times out, it's fine, as it's just intended to be a best effort.
              # There is no GameConfig in scope here, so the message must not reference one.
              logger.exception("Failed to fully restock all point caches")
  # Shared module-level store; the US urban source has a restock target of 10 points.
  points = PointStore(cache_targets={(GenMethodEnum.urban, "us"): 10})