Browse Source

Make restocking granular again

Kirk Trombley 4 years ago
parent
commit
74fb4b7048
3 changed files with 10 additions and 11 deletions
  1. 1 1
      server/app/api/game.py
  2. 3 3
      server/app/point_gen/__init__.py
  3. 6 7
      server/app/point_gen/shared.py

+ 1 - 1
server/app/api/game.py

@@ -34,7 +34,7 @@ def create_game(config: GameConfig, bg: BackgroundTasks, db: Session = Depends(g
         # TODO might be worth logging something useful here eventually
         raise HTTPException(status_code=501, detail="Sufficient points could not be generated quickly enough")
     game_id = queries.create_game(db, config, coords)
-    bg.add_task(restock_source, config.generation_method)
+    bg.add_task(restock_source, config)
     return { "gameId": game_id }
 
 

+ 3 - 3
server/app/point_gen/__init__.py

@@ -17,11 +17,11 @@ def generate_points(config: GameConfig) -> List[Tuple[float, float]]:
     return source_groups[config.generation_method].get_points_from(config.rounds, config.country_lock)
 
 
-def restock_source(generation_method: GenMethodEnum):
+def restock_source(config: GameConfig):
     """
-    Restock any caches associated with the generation method.
+    Restock any caches associated with the GameConfig.
     """
-    source_groups[generation_method].restock_all()
+    source_groups[config.generation_method].restock(config.country_lock)
 
 
 def get_cache_info() -> List[CacheInfo]:

+ 6 - 7
server/app/point_gen/shared.py

@@ -112,16 +112,15 @@ class GeoPointSourceGroup:
     def __init__(self, sources: Dict[str, GeoPointSource], default: GeoPointSource):
         self.sources = sources
         self.default = default
-        self.cached = [s for s in sources.values() if isinstance(s, CachedGeoPointSource)]
-        if isinstance(default, CachedGeoPointSource):
-            self.cached.append(default)
 
-    def restock_all(self):
+    def restock(self, key: Union[str, None] = None):
         """
-        Restock any and all CachedGeoPointSources managed by this group.
+        Restock a CachedGeoPointSources managed by this group.
+        If the targeted GeoPointSource is uncached, this method does nothing.
         """
-        for s in self.cached:
-            s.restock()
+        src = self.sources.get(key, self.default)
+        if isinstance(src, CachedGeoPointSource):
+            src.restock()
 
     def get_points_from(self, n: int, key: Union[str, None] = None) -> List[Tuple[float, float]]:
         """