|
@@ -47,7 +47,7 @@ LMS_TO_OKLAB = np.array([
|
|
|
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
|
|
|
|
|
|
# round output to this many decimals
|
|
|
-OUTPUT_PRECISION = 6
|
|
|
+OUTPUT_PRECISION = 8
|
|
|
|
|
|
|
|
|
def oklab2hex(pixel: np.array) -> str:
|
|
@@ -73,7 +73,7 @@ class Stats:
|
|
|
hex: str
|
|
|
|
|
|
|
|
|
-def calc_statistics(pixels: np.array) -> Stats:
|
|
|
+def calc_statistics(pixels: np.array, output_precision: int) -> Stats:
|
|
|
# Euclidean norm squared by summing squared components
|
|
|
sqnorms = (pixels ** 2).sum(axis=1)
|
|
|
# centroid, the arithmetic mean pixel of the image
|
|
@@ -91,17 +91,17 @@ def calc_statistics(pixels: np.array) -> Stats:
|
|
|
# and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
|
|
|
hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
|
|
|
return Stats(
|
|
|
- centroid=list(np.round(centroid, OUTPUT_PRECISION)),
|
|
|
- tilt=list(np.round(tilt, OUTPUT_PRECISION)),
|
|
|
- variance=round(variance, OUTPUT_PRECISION),
|
|
|
- chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
|
|
|
- hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
|
|
|
+ centroid=list(np.round(centroid, output_precision)),
|
|
|
+ tilt=list(np.round(tilt, output_precision)),
|
|
|
+ variance=round(variance, output_precision),
|
|
|
+ chroma=round(chroma.mean(axis=0), output_precision),
|
|
|
+ hue=round(hue % (2 * math.pi), output_precision),
|
|
|
size=len(pixels),
|
|
|
hex=oklab2hex(centroid),
|
|
|
)
|
|
|
|
|
|
|
|
|
-def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
|
|
|
+def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]:
|
|
|
means, labels = max(
|
|
|
(
|
|
|
# Try k = 2, 3, and 4, and try a few times for each
|
|
@@ -119,7 +119,7 @@ def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
|
|
|
.sum(axis=1)
|
|
|
.mean(axis=0)
|
|
|
)
|
|
|
- return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
|
|
|
+ return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))]
|
|
|
|
|
|
|
|
|
def get_srgb_pixels(img: Image.Image) -> np.array:
|
|
@@ -143,7 +143,7 @@ def search_files(image_dir: str) -> dict[str, list[str]]:
|
|
|
|
|
|
|
|
|
class Ingester:
|
|
|
- def __init__(self, dex: dict) -> None:
|
|
|
+ def __init__(self, dex: dict, output_precision: int, seed: int) -> None:
|
|
|
self.lookup = {
|
|
|
form["name"]: {
|
|
|
"num": pkmn["num"],
|
|
@@ -153,20 +153,30 @@ class Ingester:
|
|
|
for pkmn in dex.values()
|
|
|
for form in pkmn["forms"]
|
|
|
}
|
|
|
+ self.seed = seed
|
|
|
+ self.output_precision = output_precision
|
|
|
|
|
|
- def __call__(self, args: tuple[str, list[str]]) -> tuple[Stats, list[Stats]]:
|
|
|
+ def __call__(self, args: tuple[str, list[str]]) -> dict | Exception:
|
|
|
form, filenames = args
|
|
|
print(f"Ingesting {form}...")
|
|
|
start_time = time.time()
|
|
|
- all_pixels = np.concatenate([
|
|
|
- get_srgb_pixels(Image.open(fn)) for fn in filenames
|
|
|
- ])
|
|
|
- oklab = srgb2oklab(all_pixels)
|
|
|
- conv_time = time.time()
|
|
|
- total = calc_statistics(oklab)
|
|
|
- calc_time = time.time()
|
|
|
- clusters = [asdict(c) for c in calc_clusters(oklab, seed=seed)]
|
|
|
- cluster_time = time.time()
|
|
|
+ try:
|
|
|
+ all_pixels = np.concatenate([
|
|
|
+ get_srgb_pixels(Image.open(fn)) for fn in filenames
|
|
|
+ ])
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error loading images for {form}: {e}")
|
|
|
+ return e
|
|
|
+ try:
|
|
|
+ oklab = srgb2oklab(all_pixels)
|
|
|
+ conv_time = time.time()
|
|
|
+ total = calc_statistics(oklab, self.output_precision)
|
|
|
+ calc_time = time.time()
|
|
|
+ clusters = calc_clusters(oklab, self.output_precision, seed=self.seed)
|
|
|
+ cluster_time = time.time()
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error calculating statistics for {form}: {e}")
|
|
|
+ return e
|
|
|
print(
|
|
|
f"Completed {form}: ",
|
|
|
f"{(cluster_time - start_time):.02f}s in total,",
|
|
@@ -178,11 +188,15 @@ class Ingester:
|
|
|
return {
|
|
|
**self.lookup[form],
|
|
|
"total": asdict(total),
|
|
|
- "clusters": clusters,
|
|
|
+ "clusters": [asdict(c) for c in clusters],
|
|
|
}
|
|
|
|
|
|
|
|
|
def output_db(results: list[dict], db_file: str):
|
|
|
+ if db_file == "-":
|
|
|
+ print(json.dumps(results, indent=2))
|
|
|
+ return
|
|
|
+
|
|
|
with open(db_file, "w") as output:
|
|
|
output.write("const database = [\n")
|
|
|
for entry in results:
|
|
@@ -192,30 +206,65 @@ def output_db(results: list[dict], db_file: str):
|
|
|
output.write("]\n")
|
|
|
|
|
|
|
|
|
-def run_ingest(dex: dict, image_dir: str, exec: Executor, db_file: str):
|
|
|
- ingest = Ingester(dex)
|
|
|
- to_process = search_files(image_dir)
|
|
|
+def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str):
|
|
|
+ to_process = defaultdict(list)
|
|
|
+ missing = []
|
|
|
+ for path in filenames:
|
|
|
+ if path.is_file():
|
|
|
+ form_name = path.name.rsplit("-", 1)[0]
|
|
|
+ to_process[form_name].append(path)
|
|
|
+ else:
|
|
|
+ missing.append(path)
|
|
|
+ print(f"Missing file: {path}")
|
|
|
start = time.time()
|
|
|
results = list(exec.map(ingest, to_process.items()))
|
|
|
end = time.time()
|
|
|
+ success = [r for r in results if not isinstance(r, Exception)]
|
|
|
+ errors = [e for e in results if isinstance(e, Exception)]
|
|
|
print(
|
|
|
f"Finished ingest of {len(to_process)} forms",
|
|
|
f"and {sum(len(fns) for fns in to_process.values())} files",
|
|
|
- f"in {(end - start):.2f}s"
|
|
|
+ f"in {(end - start):.2f}s",
|
|
|
+ f"with {len(missing)} missing file(s)",
|
|
|
+ f"and {len(errors)} error(s)"
|
|
|
)
|
|
|
- output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
|
|
|
- print(f"Output {len(results)} entries to {db_file}")
|
|
|
+ for e in errors:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ for m in missing:
|
|
|
+ print(f"Missing: {m}")
|
|
|
+ output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file)
|
|
|
+ print(f"Output {len(success)} entries to {db_file}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- from sys import argv
|
|
|
- dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
|
|
|
- image_dir = argv[2] if len(argv) > 2 else "images"
|
|
|
- seed = int(argv[3]) if len(argv) > 3 else 230308
|
|
|
- db_file = argv[4] if len(argv) > 4 else "data/latest.db"
|
|
|
+ from argparse import ArgumentParser
|
|
|
+ parser = ArgumentParser(
|
|
|
+ prog="Image Analyzer",
|
|
|
+ description="Analyze and summarize images based on color",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "-p", "--precision", type=int, default=8, help="Round output to this many decimal places"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "-s", "--seed", type=int, default=230308, help="Clustering seed"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "-w", "--workers", type=int, default=4, help="Worker process count"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "-o", "--output", default="data/latest.db", help="Database file"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file"
|
|
|
+ )
|
|
|
+ parser.add_argument("images", metavar="file", type=Path, nargs="+")
|
|
|
|
|
|
- with open(dex_file) as infile:
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ with open(args.pokedex) as infile:
|
|
|
dex = json.load(infile)
|
|
|
|
|
|
- with ProcessPoolExecutor(4) as pool:
|
|
|
- run_ingest(dex, image_dir, pool, db_file)
|
|
|
+ ingest = Ingester(dex, args.precision, args.seed)
|
|
|
+
|
|
|
+ with ProcessPoolExecutor(max_workers=args.workers) as pool:
|
|
|
+ run_ingest(ingest, args.images, pool, args.output)
|