import json
import math
import os
import time
from collections import defaultdict
from concurrent.futures import Executor
from dataclasses import asdict, dataclass
from itertools import combinations
from pathlib import Path

import numpy as np
from PIL import Image
from scipy.cluster import vq
from scipy.spatial.distance import cdist, euclidean


# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v):
    """Decode gamma-encoded sRGB channel values (nominally 0..1) to linear light.

    Applied elementwise to a scalar or an array of any shape (including an
    empty one); returns a float ndarray of the same shape.
    """
    v = np.asarray(v, dtype=float)
    # np.where evaluates both branches on every element, so clamp the input of
    # the power branch to the region where it is actually selected. This keeps
    # negative bases (and their NaN warnings) out without changing any result.
    return np.where(
        v <= 0.04045,
        v / 12.92,
        ((np.maximum(v, 0.04045) + 0.055) / 1.055) ** 2.4,
    )


def delinearize_lrgb(v):
    """Encode linear-light RGB channel values back to gamma-encoded sRGB.

    Inverse of ``linearize_srgb``; applied elementwise, returns a float ndarray.
    """
    v = np.asarray(v, dtype=float)
    # same branch-clamping trick as in linearize_srgb
    return np.where(
        v <= 0.0031308,
        v * 12.92,
        (np.maximum(v, 0.0031308) ** (1 / 2.4)) * 1.055 - 0.055,
    )


# https://mina86.com/2019/srgb-xyz-matrix/
# linear sRGB -> CIE XYZ, written as exact fractions
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
# CIE XYZ -> linear sRGB (stored as an ndarray like every other matrix here)
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
# fold the two linear steps into one matrix: linear sRGB -> LMS
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)

LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)

# round output to this many decimals
OUTPUT_PRECISION = 8


def oklab2hex(pixel: np.ndarray) -> str:
    """Convert one Oklab pixel ``(L, a, b)`` to an sRGB hex string ``#RRGGBB``.

    Each channel is rounded to the nearest 0..255 level and out-of-gamut
    channels are clipped. Truncating instead turned values like 254.9999 into
    ``FE``, and unclipped values produced invalid codes such as ``-2`` or ``100``.

    Raises:
        ValueError: if the pixel contains NaN (``int()`` refuses NaN), so a
            degenerate pixel fails loudly rather than yielding a bogus code.
    """
    # no need for a vectorized version, this is only for providing the mean hex
    lms = (pixel @ OKLAB_TO_LMS.T) ** 3
    srgb = delinearize_lrgb(lms @ LMS_TO_RGB.T)
    channels = np.clip(np.round(srgb * 255), 0, 255)
    return "#" + "".join(f"{int(c):02X}" for c in channels)


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    """Convert sRGB pixels with 0..255 channels (last axis = r, g, b) to Oklab.

    Returns a float array of the same shape with the last axis = (L, a, b).
    """
    lms = linearize_srgb(pixels / 255) @ RGB_TO_LMS.T
    # np.cbrt is the exact inverse of the ``** 3`` used in oklab2hex, and
    # unlike ``** (1 / 3)`` it is defined for negative inputs.
    return np.cbrt(lms) @ LMS_TO_OKLAB.T
https://stackoverflow.com/a/30305181 def geometric_median(X: np.array, eps=1e-5) -> np.array: y = np.mean(X, 0) while True: D = cdist(X, [y]) nonzeros = (D != 0)[:, 0] Dinv = 1 / D[nonzeros] Dinvs = np.sum(Dinv) W = Dinv / Dinvs T = np.sum(W * X[nonzeros], 0) num_zeros = len(X) - np.sum(nonzeros) if num_zeros == 0: y1 = T elif num_zeros == len(X): return y else: R = (T - y) * Dinvs r = np.linalg.norm(R) rinv = 0 if r == 0 else num_zeros / r y1 = max(0, 1 - rinv) * T + min(1, rinv) * y if euclidean(y, y1) < eps: return y1 y = y1 @dataclass class Stats: # vector statistics centroid: list[float] # (L, a, b) median: list[float] # (L, a, b) stddev: list[float] # (L, a, b) tilt: list[float] # (L, a, b) chroma: list[float] # (mean, stddev) # scalar statistics hue: float size: int # sRGB hex code of the centroid and median centroidHex: str medianHex: str def calc_statistics(pixels: np.array, output_precision: int) -> Stats: # centroid, the arithmetic mean pixel of the image centroid = pixels.mean(axis=0) # raw second moment, for each channel of the pixels raw_second_moment = (pixels ** 2).mean(axis=0) # stddev, the sqrt of the variance of each channel of the image # variance_x = mean(p_x^2) - mean(p_x)^2 = rsm_x - centroid_x^2 # note, summing those gives a total "variance" in color stddev = np.sqrt(raw_second_moment - centroid ** 2) # tilt, the arithmetic mean pixel of normalized image tilt = (pixels / np.linalg.norm(pixels, axis=1)[:, np.newaxis]).mean(axis=0) # chroma^2 = a^2 + b^2 chromas = np.hypot(pixels[:, 1], pixels[:, 2]) chroma_mean = chromas.mean(axis=0) # variance in chroma is E[a^2 + b^2] - E[sqrt(a^2 + b^2)]^2 # max(0, x) is present to deal with floating point error for extremely low chroma images # a more robust solution could use log space instead but this is fine for this dataset chroma_dev = math.sqrt( max(0, raw_second_moment[1] + raw_second_moment[2] - (chroma_mean ** 2))) # hue = atan2(b, a), but we need a circular mean # 
https://en.wikipedia.org/wiki/Circular_mean#Definition # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean hue = math.atan2(*(pixels[:, [2, 1]] / chromas[:, np.newaxis]).sum(axis=0)) # approximation of geometric median, primarily for display purposes median = geometric_median(pixels) return Stats( centroid=list(np.round(centroid, output_precision)), median=list(np.round(median, output_precision)), stddev=list(np.round(stddev, output_precision)), tilt=list(np.round(tilt, output_precision)), chroma=[ round(chroma_mean, output_precision), round(chroma_dev, output_precision) ], hue=round(hue % (2 * math.pi), output_precision), size=len(pixels), centroidHex=oklab2hex(centroid), medianHex=oklab2hex(median) ) def calc_clusters(pixels: np.array, output_precision: int, cluster_attempts=5, seed=0) -> list[Stats]: means, labels = min( ( # Try k = 2, 3, and 4, and try a few times for each vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i) for k in (2, 3, 4) for i in range(cluster_attempts) ), key=lambda c: # Evaluate clustering by seeing the average difference in hue angle # between the centers. Maximizing this means the clusters are highly # distinct, which gives a sense of which k was best. # This is computed by normalizing the ab-plane projection of the means, # then applying a dot product to get the cosine of the angle # between them in that plane, which is the hue difference. Minimizing # this maximizes the differences in hues. # A different clustering algorithm may be more suited here, but this # is comparatively cheap while still producing reasonable results. 
(np.array([ m1 @ m2 for m1, m2 in combinations( c[0][:, 1:] / np.linalg.norm(c[0][:, 1:], axis=1)[:, np.newaxis], 2 ) ])).mean(axis=0) ) return [calc_statistics(pixels[labels == i], output_precision) for i in range(len(means))] def get_srgb_pixels(img: Image.Image) -> np.array: rgb = [] for fr in range(getattr(img, "n_frames", 1)): img.seek(fr) rgb += [ [r, g, b] for r, g, b, a in img.convert("RGBA").getdata() if a > 0 and (r, g, b) != (0, 0, 0) ] return np.array(rgb) def log(*a, **kw): print(*a, **kw, flush=True) class Ingester: def __init__(self, dex: dict, output_precision: int, seed: int) -> None: self.lookup = { form["name"]: { "num": pkmn["num"], "species": pkmn["species"], **form, } for pkmn in dex.values() for form in pkmn["forms"] } self.seed = seed self.output_precision = output_precision def __call__(self, args: tuple[str, list[str]]) -> dict | Exception: form, filenames = args log(f"Ingesting {form}...") start_time = time.time() try: all_pixels = np.concatenate([ get_srgb_pixels(Image.open(fn)) for fn in filenames ]) except Exception as e: log(f"Error loading images for {form}: {e}") return e try: oklab = srgb2oklab(all_pixels) conv_time = time.time() total = calc_statistics(oklab, self.output_precision) calc_time = time.time() clusters = calc_clusters(oklab, self.output_precision, seed=self.seed) cluster_time = time.time() except Exception as e: log(f"Error calculating statistics for {form}: {e}") return e log( f"Completed {form}: ", f"{(cluster_time - start_time):.02f}s in total,", f"{(cluster_time - calc_time):.02f}s on clustering,", f"{(calc_time - conv_time):.02f}s on total calcs,", f"{(conv_time - start_time):.02f}s on read and conversion,", f"median {total.medianHex} and {len(clusters)} clusters" ) return { **self.lookup[form], "total": asdict(total), "clusters": [asdict(c) for c in clusters], } def output_db(results: list[dict], db_file: str): if db_file == "-": log(json.dumps(results, indent=2)) return with open(db_file, "w") as output: 
output.write("const database = [\n") for entry in results: output.write(" ") output.write(json.dumps(entry)) output.write(",\n") output.write("]\n") def run_ingest(ingest: Ingester, filenames: list[Path], exec: Executor, db_file: str): to_process = defaultdict(list) missing = [] for path in filenames: if path.is_file(): form_name = path.name.rsplit("-", 1)[0] to_process[form_name].append(path) else: missing.append(path) log(f"Missing file: {path}") start = time.time() results = list(exec.map(ingest, to_process.items())) end = time.time() success = [r for r in results if not isinstance(r, Exception)] errors = [e for e in results if isinstance(e, Exception)] log( f"Finished ingest of {len(to_process)} forms", f"and {sum(len(fns) for fns in to_process.values())} files", f"in {(end - start):.2f}s", f"with {len(missing)} missing file(s)", f"and {len(errors)} error(s)" ) for e in errors: log(f"Error: {e}") for m in missing: log(f"Missing: {m}") output_db(sorted(success, key=lambda e: (e["num"], e["name"])), db_file) log(f"Output {len(success)} entries to {db_file}") if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser( prog="Image Analyzer", description="Analyze and summarize images based on color", ) parser.add_argument( "-p", "--precision", type=int, default=8, help="Round output to this many decimal places" ) parser.add_argument( "-s", "--seed", type=int, default=230308, help="Clustering seed" ) parser.add_argument( "-w", "--workers", type=int, default=4, help="Worker process count" ) parser.add_argument( "-o", "--output", default="data/latest.db", help="Database file" ) parser.add_argument( "-d", "--pokedex", default="data/pokedex.json", help="Pokedex file" ) parser.add_argument( "--threading", action="store_true", help="Use threads instead of multiproc (slower but more stable on 3.10)" ) parser.add_argument("images", metavar="file", type=Path, nargs="+") args = parser.parse_args() with open(args.pokedex) as infile: dex = 
json.load(infile) ingest = Ingester(dex, args.precision, args.seed) from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor with (ThreadPoolExecutor if args.threading else ProcessPoolExecutor)(max_workers=args.workers) as pool: run_ingest(ingest, args.images, pool, args.output)