|
@@ -1,7 +1,12 @@
|
|
|
import math
|
|
|
+import os
|
|
|
+import time
|
|
|
+from collections import defaultdict
|
|
|
+from concurrent.futures import ProcessPoolExecutor
|
|
|
from pathlib import Path
|
|
|
from dataclasses import dataclass, asdict
|
|
|
from itertools import combinations
|
|
|
+from typing import Callable
|
|
|
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
@@ -129,34 +134,87 @@ def get_srgb_pixels(img: Image.Image) -> np.array:
|
|
|
return np.array(rgb)
|
|
|
|
|
|
|
|
|
+def search_files(image_dir: str) -> dict[str, list[str]]:
|
|
|
+ files = defaultdict(list)
|
|
|
+ for image_filename in os.listdir(image_dir):
|
|
|
+ form_name = image_filename.rsplit("-", maxsplit=1)[0]
|
|
|
+ files[form_name].append(Path(image_dir, image_filename))
|
|
|
+ return files
|
|
|
+
|
|
|
+
|
|
|
+class Ingester:
|
|
|
+ def __init__(self, dex: dict) -> None:
|
|
|
+ self.lookup = {
|
|
|
+ form["name"]: {
|
|
|
+ "num": pkmn["num"],
|
|
|
+ "species": pkmn["species"],
|
|
|
+ **form,
|
|
|
+ }
|
|
|
+ for pkmn in dex.values()
|
|
|
+ for form in pkmn["forms"]
|
|
|
+ }
|
|
|
+
|
|
|
+ def __call__(self, args: tuple[str, list[str]]) -> tuple[Stats, list[Stats]]:
|
|
|
+ form, filenames = args
|
|
|
+ print(f"Ingesting {form}...")
|
|
|
+ start_time = time.time()
|
|
|
+ all_pixels = np.concatenate([
|
|
|
+ get_srgb_pixels(Image.open(fn)) for fn in filenames
|
|
|
+ ])
|
|
|
+ oklab = srgb2oklab(all_pixels)
|
|
|
+ conv_time = time.time()
|
|
|
+ total = calc_statistics(oklab)
|
|
|
+ calc_time = time.time()
|
|
|
+ clusters = [asdict(c) for c in calc_clusters(oklab, seed=seed)]
|
|
|
+ cluster_time = time.time()
|
|
|
+ print(
|
|
|
+ f"Completed {form}: ",
|
|
|
+ f"{(cluster_time - start_time):.02f}s in total,",
|
|
|
+ f"{(cluster_time - calc_time):.02f}s on clustering,",
|
|
|
+ f"{(calc_time - conv_time):.02f}s on total calcs,",
|
|
|
+ f"{(conv_time - start_time):.02f}s on read and conversion,",
|
|
|
+ f"centroid {total.hex} and {len(clusters)} clusters"
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ **self.lookup[form],
|
|
|
+ "total": asdict(total),
|
|
|
+ "clusters": clusters,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def output_db(results: list[dict], db_file: str):
|
|
|
+ with open(db_file, "w") as output:
|
|
|
+ output.write("const database = [\n")
|
|
|
+ for entry in results:
|
|
|
+ output.write(" ")
|
|
|
+ output.write(json.dumps(entry))
|
|
|
+ output.write(",\n")
|
|
|
+ output.write("]\n")
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
from sys import argv
|
|
|
dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
|
|
|
image_dir = argv[2] if len(argv) > 2 else "images"
|
|
|
seed = int(argv[3]) if len(argv) > 3 else 230308
|
|
|
+ db_file = argv[4] if len(argv) > 4 else "data/latest.db"
|
|
|
|
|
|
- import os
|
|
|
- from collections import defaultdict
|
|
|
-
|
|
|
- to_process = defaultdict(list)
|
|
|
- for image_filename in os.listdir(image_dir):
|
|
|
- form_name = image_filename.rsplit("-", maxsplit=1)[0]
|
|
|
- to_process[form_name].append(Path(image_dir, image_filename))
|
|
|
-
|
|
|
- # TODO multiproc
|
|
|
- database = []
|
|
|
- for form, image_files in to_process.items():
|
|
|
- all_pixels = np.concatenate([
|
|
|
- get_srgb_pixels(Image.open(fn)) for fn in image_files
|
|
|
- ])
|
|
|
- oklab = srgb2oklab(all_pixels)
|
|
|
- database.append({
|
|
|
- "name": form,
|
|
|
- # TODO also get dex info - species, color, etc.
|
|
|
- "total": asdict(calc_statistics(oklab)),
|
|
|
- "clusters": [asdict(c) for c in calc_clusters(oklab, seed=seed)],
|
|
|
- })
|
|
|
-
|
|
|
- # TODO real output
|
|
|
import json
|
|
|
- print(json.dumps(database, indent=2))
|
|
|
+ with open(dex_file) as infile:
|
|
|
+ dex = json.load(infile)
|
|
|
+
|
|
|
+ ingest = Ingester(dex)
|
|
|
+ to_process = search_files(image_dir)
|
|
|
+
|
|
|
+ start = time.time()
|
|
|
+ with ProcessPoolExecutor(4) as pool:
|
|
|
+ results = list(pool.map(ingest, to_process.items()))
|
|
|
+ end = time.time()
|
|
|
+ print(
|
|
|
+ f"Finished ingest of {len(to_process)} forms",
|
|
|
+ f"and {sum(len(fns) for fns in to_process.values())} files",
|
|
|
+ f"in {(end - start):.2f}s"
|
|
|
+ )
|
|
|
+
|
|
|
+ output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
|
|
|
+ print(f"Output {len(results)} entries to {db_file}")
|