"""Oklab colour statistics and k-means colour clusters for sets of images.

Run as a script, this groups the files of an image directory by the part of
their filename before the last ``-``, converts every visible pixel to Oklab,
and prints a JSON list holding whole-set and per-cluster statistics.
"""

import json
import math
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass
from itertools import combinations
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from scipy.cluster import vq

if TYPE_CHECKING:
    # Only needed for annotations; the script entry point imports Pillow
    # itself so the colour math stays importable without it.
    from PIL import Image


# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v) -> np.ndarray:
    """Map gamma-encoded sRGB values (nominally 0..1) to linear light."""
    v = np.asarray(v, dtype=float)
    # np.where evaluates both branches; clamp the power base so values that
    # take the linear branch anyway cannot trigger invalid-power warnings.
    return np.where(
        v <= 0.04045,
        v / 12.92,
        ((np.maximum(v, 0.0) + 0.055) / 1.055) ** 2.4,
    )


def delinearize_lrgb(v) -> np.ndarray:
    """Map linear-light RGB values (nominally 0..1) to gamma-encoded sRGB."""
    v = np.asarray(v, dtype=float)
    return np.where(
        v <= 0.0031308,
        v * 12.92,
        (np.maximum(v, 0.0) ** (1 / 2.4)) * 1.055 - 0.055,
    )


# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)

# round output to this many decimals
OUTPUT_PRECISION = 6


def oklab2hex(pixel: np.ndarray) -> str:
    """Return the ``#RRGGBB`` sRGB hex code of a single Oklab colour.

    Channels are rounded to the nearest 8-bit value and clamped to 0..255,
    so colours slightly outside the sRGB gamut still give a valid code.

    Raises:
        ValueError: if the colour converts to a non-finite sRGB value.
    """
    # no need for a vectorized version, this is only for providing the mean hex
    lms = (np.asarray(pixel, dtype=float) @ OKLAB_TO_LMS.T) ** 3
    srgb = delinearize_lrgb(lms @ LMS_TO_RGB.T)
    if not np.all(np.isfinite(srgb)):
        raise ValueError("cannot convert a non-finite colour to a hex code")
    channels = np.clip(np.rint(srgb * 255), 0, 255).astype(int).tolist()
    return "#" + "".join(f"{c:02X}" for c in channels)


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    """Convert rows of 8-bit sRGB ``(r, g, b)`` values to Oklab ``(L, a, b)``."""
    linear = linearize_srgb(np.asarray(pixels, dtype=float) / 255)
    return np.cbrt(linear @ RGB_TO_LMS.T) @ LMS_TO_OKLAB.T


@dataclass
class Stats:
    """Summary statistics of a set of Oklab pixels."""

    # points (L, a, b)
    centroid: list[float]
    tilt: list[float]
    # scalar statistics
    size: int
    variance: float
    chroma: float
    hue: float
    # sRGB hex code of the centroid
    hex: str


def _safe_normalize(vectors: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Divide each row of *vectors* by its entry in *lengths*; zero-length rows become 0."""
    scale = lengths[:, np.newaxis]
    out = np.zeros_like(vectors, dtype=float)
    np.divide(vectors, scale, out=out, where=scale > 0)
    return out


def calc_statistics(pixels: np.ndarray) -> Stats:
    """Compute centroid, tilt, variance, mean chroma and circular-mean hue.

    Args:
        pixels: array of Oklab rows ``(L, a, b)``.

    Returns:
        A ``Stats`` with every float rounded to ``OUTPUT_PRECISION`` decimals
        and ``hue`` in radians within ``[0, 2*pi]``.

    Raises:
        ValueError: if *pixels* contains no rows.
    """
    pixels = np.asarray(pixels, dtype=float)
    if len(pixels) == 0:
        raise ValueError("cannot compute statistics of an empty pixel array")

    # Euclidean norm squared by summing squared components
    sqnorms = (pixels ** 2).sum(axis=1)
    # centroid, the arithmetic mean pixel of the image
    centroid = pixels.mean(axis=0)
    # tilt, the arithmetic mean pixel of normalized image
    # (a zero-norm pixel has no direction and contributes a zero vector)
    tilt = _safe_normalize(pixels, np.sqrt(sqnorms)).mean(axis=0)
    # variance = mean(||p||^2) - ||mean(p)||^2
    # clamped because rounding error can push it a hair below zero
    variance = max(float(sqnorms.mean() - (centroid ** 2).sum()), 0.0)
    # chroma^2 = a^2 + b^2
    chroma = np.hypot(pixels[:, 1], pixels[:, 2])
    # hue = atan2(b, a), but we need a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
    # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
    # and bc atan2(y/c, x/c) = atan2(y, x), this is a sum not a mean
    # (zero-chroma pixels have no hue and are left out of the sum)
    sin_sum, cos_sum = _safe_normalize(pixels[:, [2, 1]], chroma).sum(axis=0)
    hue = math.atan2(sin_sum, cos_sum)

    return Stats(
        centroid=np.round(centroid, OUTPUT_PRECISION).tolist(),
        tilt=np.round(tilt, OUTPUT_PRECISION).tolist(),
        size=len(pixels),
        variance=round(variance, OUTPUT_PRECISION),
        chroma=round(float(chroma.mean()), OUTPUT_PRECISION),
        hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
        hex=oklab2hex(centroid),
    )


def _ab_separation(centers: np.ndarray) -> float:
    """Mean squared distance in the ab-plane between every pair of centers."""
    diffs = np.array([p - q for p, q in combinations(centers[:, 1:], 2)])
    return (diffs ** 2).sum(axis=1).mean(axis=0)


def calc_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0) -> list[Stats]:
    """Cluster Oklab pixels with k-means and return statistics per cluster.

    Tries k = 2, 3 and 4, ``cluster_attempts`` times each with seeds
    ``seed``, ``seed + 1``, ..., and keeps the best-scoring clustering.

    Clusters that end up with no pixels assigned are omitted, so the
    returned list may be shorter than the chosen k.
    """
    data = np.asarray(pixels, dtype=float)
    candidates = (
        # Try k = 2, 3, and 4, and try a few times for each
        vq.kmeans2(data, k, minit="++", seed=seed + i)
        for k in (2, 3, 4)
        for i in range(cluster_attempts)
    )
    # Evaluate clustering by seeing the average distance in the ab-plane
    # between the centers. Maximizing this means the clusters are highly
    # distinct, which gives a sense of which k was best.
    # A different clustering algorithm may be more suited here, but this
    # is comparatively cheap while still producing reasonable results.
    means, labels = max(candidates, key=lambda c: _ab_separation(c[0]))

    stats = []
    for i in range(len(means)):
        members = data[labels == i]
        # kmeans2 can leave a cluster without members; it has no statistics
        if len(members) > 0:
            stats.append(calc_statistics(members))
    return stats


def get_srgb_pixels(img: "Image.Image") -> np.ndarray:
    """Collect the visible, non-black sRGB pixels from every frame of *img*.

    A pixel is kept when its alpha is above zero and at least one of its
    colour channels is non-zero.

    Returns:
        An ``(N, 3)`` integer array of ``(r, g, b)`` rows; ``(0, 3)`` when no
        pixel qualifies.
    """
    chunks = []
    for frame in range(getattr(img, "n_frames", 1)):
        img.seek(frame)
        rgba = np.asarray(img.convert("RGBA")).reshape(-1, 4)
        keep = (rgba[:, 3] > 0) & rgba[:, :3].any(axis=1)
        chunks.append(rgba[keep, :3])
    if not chunks:
        return np.empty((0, 3), dtype=np.int64)
    return np.concatenate(chunks).astype(np.int64)


def main(argv=None) -> None:
    """Script entry point: ``[dex_file] [image_dir] [seed]`` positional arguments."""
    # Imported here so the rest of the module does not require Pillow.
    from PIL import Image

    args = sys.argv if argv is None else argv
    # Not read yet; kept so the argument positions stay stable for the
    # dex-info TODO below.
    dex_file = args[1] if len(args) > 1 else "data/pokedex.json"  # noqa: F841
    image_dir = args[2] if len(args) > 2 else "images"
    seed = int(args[3]) if len(args) > 3 else 230308

    # Sorted so the pixel order, and therefore the seeded clustering, does
    # not depend on the order the filesystem lists the directory in.
    to_process = defaultdict(list)
    for image_path in sorted(Path(image_dir).iterdir()):
        form_name = image_path.name.rsplit("-", maxsplit=1)[0]
        to_process[form_name].append(image_path)

    # TODO multiproc
    database = []
    for form, image_files in to_process.items():
        per_image = []
        for fn in image_files:
            with Image.open(fn) as img:
                per_image.append(get_srgb_pixels(img))
        all_pixels = np.concatenate(per_image)
        if len(all_pixels) == 0:
            print(f"skipping {form}: no visible pixels", file=sys.stderr)
            continue
        oklab = srgb2oklab(all_pixels)
        database.append({
            "name": form,
            # TODO also get dex info - species, color, etc.
            "total": asdict(calc_statistics(oklab)),
            "clusters": [asdict(c) for c in calc_clusters(oklab, seed=seed)],
        })

    # TODO real output
    print(json.dumps(database, indent=2))


if __name__ == "__main__":
    main()