"""Ingest sprite images, compute Oklab color statistics for each Pokémon form,
and write the results out as a JavaScript database.

Optional positional arguments: dex_file, image_dir, seed, db_file.
"""
import json
import math
import os
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, asdict
from itertools import combinations
from pathlib import Path

import numpy as np
from PIL import Image
from scipy.cluster import vq

# https://en.wikipedia.org/wiki/SRGB#Transformation
linearize_srgb = np.vectorize(
    lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
)
delinearize_lrgb = np.vectorize(
    lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
)
# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])
# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)

# round output to this many decimals
OUTPUT_PRECISION = 6


def oklab2hex(pixel: np.ndarray) -> str:
    # no need for a vectorized version, this is only for providing the mean hex
    rgb = delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T)
    # clamp so a slightly out-of-gamut centroid still formats as a valid hex code
    return "#" + "".join(f"{int(x * 255):02X}" for x in np.clip(rgb, 0, 1))


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T


@dataclass
class Stats:
    # points (L, a, b)
    centroid: list[float]
    tilt: list[float]
    # scalar statistics
    variance: float
    chroma: float
    hue: float
    size: int
    # sRGB hex code of the centroid
    hex: str


def calc_statistics(pixels: np.ndarray) -> Stats:
    # squared Euclidean norm of each pixel, by summing squared components
    sqnorms = (pixels ** 2).sum(axis=1)
    # centroid, the arithmetic mean pixel of the image
    centroid = pixels.mean(axis=0)
    # tilt, the arithmetic mean of the pixels normalized to unit length
    tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
    # variance = mean(||p||^2) - ||mean(p)||^2
    variance = sqnorms.mean(axis=0) - sum(centroid ** 2)
    # chroma = sqrt(a^2 + b^2), computed per pixel
    chroma = np.hypot(pixels[:, 1], pixels[:, 2])
    # hue = atan2(b, a), but averaging angles requires a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
    # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
    # and because atan2(y/n, x/n) = atan2(y, x), a sum works as well as a mean
    hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).sum(axis=0))
    return Stats(
        centroid=list(np.round(centroid, OUTPUT_PRECISION)),
        tilt=list(np.round(tilt, OUTPUT_PRECISION)),
        variance=round(variance, OUTPUT_PRECISION),
        chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
        hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
        size=len(pixels),
        hex=oklab2hex(centroid),
    )


def calc_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0) -> list[Stats]:
    means, labels = max(
        (
            # try k = 2, 3, and 4, with a few attempts for each
            vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
            for k in (2, 3, 4)
            for i in range(cluster_attempts)
        ),
        key=lambda c:
            # Evaluate a clustering by the mean squared distance in the ab-plane
            # between its centers. Maximizing this favors clusters that are highly
            # distinct, which gives a sense of which k was best.
            # A different clustering algorithm may be better suited here, but this
            # is comparatively cheap while still producing reasonable results.
            (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
            .sum(axis=1)
            .mean(axis=0)
    )
    return [calc_statistics(pixels[labels == i]) for i in range(len(means))]


def get_srgb_pixels(img: Image.Image) -> np.ndarray:
    rgb = []
    # walk every frame of animated images; still images report a single frame
    for fr in range(getattr(img, "n_frames", 1)):
        img.seek(fr)
        rgb += [
            [r, g, b]
            for r, g, b, a in img.convert("RGBA").getdata()
            # drop fully transparent pixels and the pure black background
            if a > 0 and (r, g, b) != (0, 0, 0)
        ]
    return np.array(rgb)


def search_files(image_dir: str) -> dict[str, list[Path]]:
    # group image files by form name, i.e. everything before the final "-"
    files = defaultdict(list)
    for image_filename in os.listdir(image_dir):
        form_name = image_filename.rsplit("-", maxsplit=1)[0]
        files[form_name].append(Path(image_dir, image_filename))
    return files


class Ingester:
    def __init__(self, dex: dict, seed: int) -> None:
        # carry the seed explicitly so worker processes don't rely on a module-level global
        self.seed = seed
        # flatten the dex into a lookup from form name to its metadata
        self.lookup = {
            form["name"]: {
                "num": pkmn["num"],
                "species": pkmn["species"],
                **form,
            }
            for pkmn in dex.values()
            for form in pkmn["forms"]
        }

    def __call__(self, args: tuple[str, list[Path]]) -> dict:
        form, filenames = args
        print(f"Ingesting {form}...")
        start_time = time.time()
        all_pixels = np.concatenate([
            get_srgb_pixels(Image.open(fn)) for fn in filenames
        ])
        oklab = srgb2oklab(all_pixels)
        conv_time = time.time()
        total = calc_statistics(oklab)
        calc_time = time.time()
        clusters = [asdict(c) for c in calc_clusters(oklab, seed=self.seed)]
        cluster_time = time.time()
        print(
            f"Completed {form}:",
            f"{(cluster_time - start_time):.02f}s in total,",
            f"{(cluster_time - calc_time):.02f}s on clustering,",
            f"{(calc_time - conv_time):.02f}s on total calcs,",
            f"{(conv_time - start_time):.02f}s on read and conversion,",
            f"centroid {total.hex} and {len(clusters)} clusters",
        )
        return {
            **self.lookup[form],
            "total": asdict(total),
            "clusters": clusters,
        }


def output_db(results: list[dict], db_file: str):
    # write the entries as a JavaScript array literal, one entry per line
    with open(db_file, "w") as output:
        output.write("const database = [\n")
        for entry in results:
            output.write(" ")
            output.write(json.dumps(entry))
            output.write(",\n")
        output.write("]\n")


if __name__ == "__main__":
    from sys import argv
    dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
    image_dir = argv[2] if len(argv) > 2 else "images"
    seed = int(argv[3]) if len(argv) > 3 else 230308
    db_file = argv[4] if len(argv) > 4 else "data/latest.db"

    with open(dex_file) as infile:
        dex = json.load(infile)

    ingest = Ingester(dex, seed)
    to_process = search_files(image_dir)

    start = time.time()
    with ProcessPoolExecutor(4) as pool:
        results = list(pool.map(ingest, to_process.items()))
    end = time.time()

    print(
        f"Finished ingest of {len(to_process)} forms",
        f"and {sum(len(fns) for fns in to_process.values())} files",
        f"in {(end - start):.2f}s",
    )

    output_db(sorted(results, key=lambda e: (e["num"], e["name"])), db_file)
    print(f"Output {len(results)} entries to {db_file}")