__init__.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import asyncio
  2. import collections
  3. import random
  4. from typing import List, Tuple, Dict, Union
  5. from .random_street_view import call_random_street_view, VALID_COUNTRIES as RSV_COUNTRIES
  6. from .urban_centers import urban_coord_unlocked, urban_coord_ensured, VALID_COUNTRIES as URBAN_COUNTRIES
  7. from .shared import ExhaustedSourceError, aiohttp_client
  8. from ..schemas import GameConfig, GenMethodEnum, CountryCode, CacheInfo, GeneratorInfo
  9. generator_info = [
  10. GeneratorInfo(
  11. generation_method=GenMethodEnum.rsv,
  12. country_locks=RSV_COUNTRIES
  13. ),
  14. GeneratorInfo(
  15. generation_method=GenMethodEnum.urban,
  16. country_locks=URBAN_COUNTRIES
  17. ),
  18. ]
  19. cache_names = {
  20. GenMethodEnum.rsv: "RSV",
  21. GenMethodEnum.urban: "Urban",
  22. }
  23. class PointStore:
  24. def __init__(self, cache_targets: Dict[Tuple[GenMethodEnum, CountryCode], int]):
  25. self.cache_targets = cache_targets
  26. self.store = collections.defaultdict(collections.deque)
  27. async def generate_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
  28. if generator == GenMethodEnum.rsv:
  29. # RSV point functions return a collection of points, which should be cached
  30. point, *points = await call_random_street_view(country)
  31. # use the country on the point - since country itself might be None
  32. self.store[(generator, point[0])].extend(points)
  33. return point
  34. elif generator == GenMethodEnum.urban:
  35. # urban center point functions only return a single point
  36. if country is None:
  37. return await urban_coord_unlocked()
  38. return await urban_coord_ensured(country, city_retries=50)
  39. else:
  40. raise ExhaustedSourceError
  41. async def get_point(self, generator: GenMethodEnum, country: Union[CountryCode, None]) -> Tuple[str, float, float]:
  42. if country is not None:
  43. # if we already have a point ready, just return it immediately
  44. # to avoid bias, we only do this in country-locking mode
  45. stock = self.store[(generator, country)]
  46. if len(stock) > 0:
  47. return stock.popleft()
  48. return await self.generate_point(generator, country)
  49. async def get_points(self, config: GameConfig) -> List[Tuple[str, float, float]]:
  50. """
  51. Provide points according to the GameConfig.
  52. Return a list of at least n valid geo points, as
  53. (2 character country code, latitude, longitude) tuples.
  54. In the event that the configured source cannot reasonably supply enough points,
  55. most likely due to time constraints, this will raise an ExhaustedSourceError.
  56. """
  57. return await asyncio.gather(*[self.get_point(config.generation_method, config.country_lock) for _ in range(config.rounds)])
  58. def get_cache_info(self) -> List[CacheInfo]:
  59. """
  60. Get CacheInfo for all caches.
  61. """
  62. return [CacheInfo(cache_name=f"{cache_names[g]}-{c}", size=len(ps)) for (g, c), ps in self.store.items()]
  63. async def restock_source(self, config: GameConfig):
  64. """
  65. Restock any caches associated with the GameConfig.
  66. """
  67. if config.country_lock is None:
  68. return
  69. key = (config.generation_method, config.country_lock)
  70. target = self.cache_targets.get(key, 0)
  71. stock = self.store[key]
  72. while len(stock) < target:
  73. stock.append(await self.generate_point(*key))
  74. points = PointStore({
  75. (GenMethodEnum.urban, "us"): 10,
  76. })