
parallelize analyzer and refactor scratch work

Kirk Trombley, 2 years ago
commit 316a588abe
1 changed file with 89 additions and 24 deletions

+ 89 - 24
tools/analyze.py

@@ -1,7 +1,12 @@
 import math
+import os
+import time
+from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from dataclasses import dataclass, asdict
 from itertools import combinations
+from typing import Callable
 
 import numpy as np
 from PIL import Image
@@ -129,34 +134,94 @@ def get_srgb_pixels(img: Image.Image) -> np.array:
   return np.array(rgb)
 
 
+def search_files(image_dir: str) -> dict[str, list[Path]]:
+  # Group image files by form name, i.e. the filename up to its final "-" suffix.
+  files = defaultdict(list)
+  for image_filename in os.listdir(image_dir):
+    form_name = image_filename.rsplit("-", maxsplit=1)[0]
+    files[form_name].append(Path(image_dir, image_filename))
+  return files
+
+
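+# Ingests one form per call; instances are pickled out to the ProcessPoolExecutor workers below.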
+class Ingester:
+  def __init__(self, dex: dict, seed: int) -> None:
+    self.seed = seed
+    # Flatten the dex into a per-form lookup, carrying num and species onto each form entry.
+    self.lookup = {
+      form["name"]: {
+        "num": pkmn["num"],
+        "species": pkmn["species"],
+        **form,
+      }
+      for pkmn in dex.values()
+      for form in pkmn["forms"]
+    }
+
+  def __call__(self, args: tuple[str, list[Path]]) -> dict:
+    # Ingest one form: pool pixels from every sprite, convert to Oklab, compute stats and clusters.
+    form, filenames = args
+    print(f"Ingesting {form}...")
+    start_time = time.time()
+    all_pixels = np.concatenate([
+      get_srgb_pixels(Image.open(fn)) for fn in filenames
+    ])
+    oklab = srgb2oklab(all_pixels)
+    conv_time = time.time()
+    total = calc_statistics(oklab)
+    calc_time = time.time()
+    clusters = [asdict(c) for c in calc_clusters(oklab, seed=self.seed)]
+    cluster_time = time.time()
+    print(
+      f"Completed {form}: ",
+      f"{(cluster_time - start_time):.02f}s in total,",
+      f"{(cluster_time - calc_time):.02f}s on clustering,",
+      f"{(calc_time - conv_time):.02f}s on total calcs,",
+      f"{(conv_time - start_time):.02f}s on read and conversion,",
+      f"centroid {total.hex} and {len(clusters)} clusters"
+    )
+    return {
+      **self.lookup[form],
+      "total": asdict(total),
+      "clusters": clusters,
+    }
+
+
+def output_db(results: list[dict], db_file: str):
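+  # Write the entries as a JS source file declaring "const database = [...]", one entry per line.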
+  with open(db_file, "w") as output:
+    output.write("const database = [\n")
+    for entry in results:
+      output.write("  ")
+      output.write(json.dumps(entry))
+      output.write(",\n")
+    output.write("]\n")
+
+
 if __name__ == "__main__":
   from sys import argv
   dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
   image_dir = argv[2] if len(argv) > 2 else "images"
   seed = int(argv[3]) if len(argv) > 3 else 230308
+  db_file = argv[4] if len(argv) > 4 else "data/latest.db"
 
-  import os
-  from collections import defaultdict
-
-  to_process = defaultdict(list)
-  for image_filename in os.listdir(image_dir):
-    form_name = image_filename.rsplit("-", maxsplit=1)[0]
-    to_process[form_name].append(Path(image_dir, image_filename))
-
-  # TODO multiproc
-  database = []
-  for form, image_files in to_process.items():
-    all_pixels = np.concatenate([
-      get_srgb_pixels(Image.open(fn)) for fn in image_files
-    ])
-    oklab = srgb2oklab(all_pixels)
-    database.append({
-      "name": form,
-      # TODO also get dex info - species, color, etc.
-      "total": asdict(calc_statistics(oklab)),
-      "clusters": [asdict(c) for c in calc_clusters(oklab, seed=seed)],
-    })
-
-  # TODO real output
   import json
-  print(json.dumps(database, indent=2))
+  with open(dex_file) as infile:
+    dex = json.load(infile)
+
+  ingest = Ingester(dex, seed)
+  to_process = search_files(image_dir)
+
+  start = time.time()
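+  # Fan the per-form ingest out across 4 worker processes.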
+  with ProcessPoolExecutor(4) as pool:
+    results = list(pool.map(ingest, to_process.items()))
+  end = time.time()
+  print(
+    f"Finished ingest of {len(to_process)} forms",
+    f"and {sum(len(fns) for fns in to_process.values())} files",
+    f"in {(end - start):.2f}s"
+  )
+
+  output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
+  print(f"Output {len(results)} entries to {db_file}")
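
As a quick sanity check on the new output format (not part of this commit), data/latest.db can be read back into Python by stripping the JS wrapper that output_db writes. A minimal sketch, assuming the default paths used above:

  import json
  from pathlib import Path

  # data/latest.db is a JS source file: "const database = [", then one JSON object
  # per line (each followed by a trailing comma), then "]".
  text = Path("data/latest.db").read_text()
  body = text.removeprefix("const database = [\n").rstrip()
  assert body.endswith("]")
  lines = [ln.strip().rstrip(",") for ln in body[:-1].splitlines() if ln.strip()]
  entries = [json.loads(ln) for ln in lines]
  print(f"{len(entries)} entries, e.g. {entries[0]['species']} centroid {entries[0]['total']['hex']}")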