Kirk Trombley пре 2 година
родитељ
комит
c6cc18f88e
1 измењених фајлова са 84 додато и 35 уклоњено
  1. 84 35
      tools/analyze.py

+ 84 - 35
tools/analyze.py

@@ -47,7 +47,7 @@ LMS_TO_OKLAB = np.array([
 OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
 
 # round output to this many decimals
-OUTPUT_PRECISION = 6
+OUTPUT_PRECISION = 8
 
 
 def oklab2hex(pixel: np.array) -> str:
@@ -73,7 +73,7 @@ class Stats:
   hex: str
 
 
-def calc_statistics(pixels: np.array) -> Stats:
+def calc_statistics(pixels: np.array, output_precision: int) -> Stats:
   # Euclidean norm squared by summing squared components
   sqnorms = (pixels ** 2).sum(axis=1)
   # centroid, the arithmetic mean pixel of the image
@@ -91,17 +91,17 @@ def calc_statistics(pixels: np.array) -> Stats:
   # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
   hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
   return Stats(
-    centroid=list(np.round(centroid, OUTPUT_PRECISION)),
-    tilt=list(np.round(tilt, OUTPUT_PRECISION)),
-    variance=round(variance, OUTPUT_PRECISION),
-    chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
-    hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
+    centroid=list(np.round(centroid, output_precision)),
+    tilt=list(np.round(tilt, output_precision)),
+    variance=round(variance, output_precision),
+    chroma=round(chroma.mean(axis=0), output_precision),
+    hue=round(hue % (2 * math.pi), output_precision),
     size=len(pixels),
     hex=oklab2hex(centroid),
   )
 
 
-def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
+def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
   means, labels = max(
     (
       # Try k = 2, 3, and 4, and try a few times for each
@@ -119,7 +119,7 @@ def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
         .sum(axis=1)
         .mean(axis=0)
   )
-  return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
+  return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))]
 
 
 def get_srgb_pixels(img: Image.Image) -> np.array:
@@ -143,7 +143,7 @@ def search_files(image_dir: str) -> dict[str, list[str]]:
 
 
 class Ingester:
-  def __init__(self, dex: dict) -> None:
+  def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
     self.lookup = {
       form["name"]: {
         "num": pkmn["num"],
@@ -153,20 +153,30 @@ class Ingester:
       for pkmn in dex.values()
       for form in pkmn["forms"]
     }
+    self.seed = seed
+    self.output_precision = output_precision
 
-  def __call__(self, args: tuple[str, list[str]]) -> tuple[Stats, list[Stats]]:
+  def __call__(self, args: tuple[str, list[str]]) -> dict | Exception:
     form, filenames = args
     print(f"Ingesting {form}...")
     start_time = time.time()
-    all_pixels = np.concatenate([
-      get_srgb_pixels(Image.open(fn)) for fn in filenames
-    ])
-    oklab = srgb2oklab(all_pixels)
-    conv_time = time.time()
-    total = calc_statistics(oklab)
-    calc_time = time.time()
-    clusters = [asdict(c) for c in calc_clusters(oklab, seed=seed)]
-    cluster_time = time.time()
+    try:
+      all_pixels = np.concatenate([
+        get_srgb_pixels(Image.open(fn)) for fn in filenames
+      ])
+    except Exception as e:
+      print(f"Error loading images for {form}: {e}")
+      return e
+    try:
+      oklab = srgb2oklab(all_pixels)
+      conv_time = time.time()
+      total = calc_statistics(oklab, self.output_precision)
+      calc_time = time.time()
+      clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
+      cluster_time = time.time()
+    except Exception as e:
+      print(f"Error calculating statistics for {form}: {e}")
+      return e
     print(
       f"Completed {form}: ",
       f"{(cluster_time - start_time):.02f}s in total,",
@@ -178,11 +188,15 @@ class Ingester:
     return {
       **self.lookup[form],
       "total": asdict(total),
-      "clusters": clusters,
+      "clusters": [asdict(c) for c in clusters],
     }
 
 
 def output_db(results: list[dict], db_file: str):
+  if db_file == "-":
+    print(json.dumps(results, indent=2))
+    return
+
   with open(db_file, "w") as output:
     output.write("const database = [\n")
     for entry in results:
@@ -192,30 +206,65 @@ def output_db(results: list[dict], db_file: str):
     output.write("]\n")
 
 
-def run_ingest(dex: dict, image_dir: str, exec: Executor, db_file: str):
-  ingest = Ingester(dex)
-  to_process = search_files(image_dir)
+def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str):
+  to_process = defaultdict(list)  # form name -> image paths belonging to that form
+  missing = []  # paths that were requested but are not regular files
+  for path in filenames:
+    if path.is_file():
+      form_name = path.name.rsplit("-", 1)[0]  # form name is everything before the last "-"
+      to_process[form_name].append(path)
+    else:
+      missing.append(path)
+      print(f"Missing file: {path}")
   start = time.time()
   results = list(exec.map(ingest, to_process.items()))
   end = time.time()
+  success = [r for r in results if not isinstance(r, Exception)]  # ingest returns the exception instead of raising
+  errors = [e for e in results if isinstance(e, Exception)]
   print(
     f"Finished ingest of {len(to_process)} forms",
     f"and {sum(len(fns) for fns in to_process.values())} files",
-    f"in {(end - start):.2f}s"
+    f"in {(end - start):.2f}s",
+    f"with {len(missing)} missing file(s)",
+    f"and {len(errors)} error(s)"
   )
-  output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
-  print(f"Output {len(results)} entries to {db_file}")
+  for e in errors:
+    print(f"Error: {e}")
+  for m in missing:
+    print(f"Missing: {m}")
+  output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file)
+  print(f"Output {len(success)} entries to {db_file}")
 
 
 if __name__ == "__main__":
-  from sys import argv
-  dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
-  image_dir = argv[2] if len(argv) > 2 else "images"
-  seed = int(argv[3]) if len(argv) > 3 else 230308
-  db_file = argv[4] if len(argv) > 4 else "data/latest.db"
+  from argparse import ArgumentParser  # imported here because only the script entry point parses arguments
+  parser = ArgumentParser(
+    prog="Image Analyzer",
+    description="Analyze and summarize images based on color",
+  )
+  parser.add_argument(
+    "-p", "--precision", type=int, default=8, help="Round output to this many decimal places"
+  )
+  parser.add_argument(
+    "-s", "--seed", type=int, default=230308, help="Clustering seed"
+  )
+  parser.add_argument(
+    "-w", "--workers", type=int, default=4, help="Worker process count"
+  )
+  parser.add_argument(
+    "-o", "--output", default="data/latest.db", help="Database file"
+  )
+  parser.add_argument(
+    "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
+  )
+  parser.add_argument("images", metavar="file", type=Path, nargs="+")  # one or more image paths, converted to Path
 
-  with open(dex_file) as infile:
+  args = parser.parse_args()
+
+  with open(args.pokedex) as infile:
     dex = json.load(infile)
 
-  with ProcessPoolExecutor(4) as pool:
-    run_ingest(dex, image_dir, pool, db_file)
+  ingest = Ingester(dex, args.precision, args.seed)  # callable that the pool maps over each (form, files) pair
+
+  with ProcessPoolExecutor(max_workers=args.workers) as pool:  # pool size comes from -w/--workers
+    run_ingest(ingest, args.images, pool, args.output)