"""
Ingest sprite images and compute OKLab colour statistics and colour clusters.

Goals:
+ Single module
+ Use OKLab
+ Improved clustering logic
+ Parallel, in the same way as anim-ingest
+ Async requests for downloads
* Include more info about the pokemon (form, display name, icon sprite source)
* Include more images (get more stills from pokemondb + serebii)
* Include shinies + megas, tagged so the UI can filter them
* Fallback automatically (try showdown animated, then showdown gen5, then pdb)
* Filtering system more explicit and easier to work around
* Output a record of ingest for auditing
* Automatic retry of a partially failed ingest, using record
"""

import asyncio
import functools
import json
import math
import multiprocessing
from collections import defaultdict
from io import BytesIO
from itertools import combinations
from typing import NamedTuple, Generator

import numpy as np
from PIL import Image
from aiohttp import ClientSession
from scipy.cluster import vq


# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v):
    """Convert gamma-encoded sRGB components (0..1) to linear-light values.

    Works element-wise on scalars or arrays of any shape, including empty ones.
    """
    v = np.asarray(v, dtype=float)
    # np.where evaluates both branches for every element, so the base of the
    # power is clamped to stay non-negative even where the linear branch wins.
    return np.where(
        v <= 0.04045,
        v / 12.92,
        ((np.maximum(v, 0.0) + 0.055) / 1.055) ** 2.4,
    )


def delinearize_lrgb(v):
    """Convert linear-light RGB components to gamma-encoded sRGB (0..1).

    Works element-wise on scalars or arrays of any shape, including empty ones.
    """
    v = np.asarray(v, dtype=float)
    # Same clamping rationale as linearize_srgb: avoid fractional powers of
    # negative numbers in the branch that np.where will discard.
    return np.where(
        v <= 0.0031308,
        v * 12.92,
        (np.maximum(v, 0.0) ** (1 / 2.4)) * 1.055 - 0.055,
    )


# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
# Stored as an ndarray like every other matrix here so it can be used with `@`.
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)


def oklab2hex(pixel: np.ndarray) -> str:
    """Return the ``#RRGGBB`` sRGB hex string for a single OKLab pixel.

    No vectorized version is needed: this is only used for the mean colour.
    Colours outside the sRGB gamut are clipped to it, and each channel is
    rounded (not truncated) so that e.g. white round-trips to ``#FFFFFF``.
    """
    linear_rgb = ((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T
    srgb = np.clip(delinearize_lrgb(np.clip(linear_rgb, 0.0, 1.0)), 0.0, 1.0)
    return "#" + "".join(f"{round(float(x) * 255):02X}" for x in srgb)


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    """Convert an ``(n, 3)`` array of 0..255 sRGB pixels to OKLab ``(L, a, b)``."""
    lms = linearize_srgb(np.asarray(pixels) / 255) @ RGB_TO_LMS.T
    return lms ** (1 / 3) @ LMS_TO_OKLAB.T


class Stats(NamedTuple):
    """Summary statistics for a set of OKLab pixels (see calc_statistics)."""

    size: int        # number of pixels
    variance: float  # mean(||p||^2) - ||mean(p)||^2
    stddev: float    # sqrt(variance)
    hex: str         # sRGB hex of the mean pixel
    Lbar: float      # mean L
    abar: float      # mean a
    bbar: float      # mean b
    Cbar: float      # mean chroma, mean(sqrt(a^2 + b^2))
    hbar: float      # mean hue in degrees, mean(atan2(b, a))
    Lhat: float      # mean L of the unit-normalized pixels
    ahat: float      # mean a of the unit-normalized pixels
    bhat: float      # mean b of the unit-normalized pixels


def calc_statistics(pixels: np.ndarray) -> Stats:
    """Compute the Stats summary of an ``(n, 3)`` array of OKLab pixels.

    Raises:
        ValueError: if ``pixels`` contains no rows.
    """
    if len(pixels) == 0:
        raise ValueError("cannot compute statistics of an empty pixel set")

    # mean pixel of the image, (L-bar, a-bar, b-bar)
    mean = pixels.mean(axis=0)
    # square each component
    squared = pixels ** 2
    # Euclidean norm squared by summing squared components
    sqnorms = squared.sum(axis=1)
    # mean pixel of normalized image, (L-hat, a-hat, b-hat); a zero-norm pixel
    # has no direction, so it contributes the zero vector instead of NaN
    norms = np.sqrt(sqnorms)
    safe_norms = np.where(norms > 0, norms, 1.0)
    tilt = (pixels / safe_norms[:, np.newaxis]).mean(axis=0)
    # variance = mean(||p||^2) - ||mean(p)||^2; rounding can push this a hair
    # below zero for (near-)uniform inputs, which would break math.sqrt
    variance = max(0.0, float(sqnorms.mean(axis=0) - np.dot(mean, mean)))
    # chroma^2 = a^2 + b^2, C-bar = mean(sqrt(a^2 + b^2))
    chroma = np.sqrt(squared[:, 1:].sum(axis=1)).mean(axis=0)
    # hue = atan2(b, a), h-bar = mean(atan2(b, a))
    # NOTE(review): this is an arithmetic mean of angles, so hues on either
    # side of +/-180 degrees average towards 0; confirm that is intended.
    hue = np.arctan2(pixels[:, 2], pixels[:, 1]).mean(axis=0) * 180 / math.pi

    return Stats(
        size=len(pixels),
        variance=variance,
        stddev=math.sqrt(variance),
        hex=oklab2hex(mean),
        Lbar=mean[0],
        abar=mean[1],
        bbar=mean[2],
        Cbar=chroma,
        hbar=hue,
        Lhat=tilt[0],
        ahat=tilt[1],
        bhat=tilt[2],
    )


def find_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0) -> list[Stats]:
    """Cluster OKLab pixels with k-means and return statistics per cluster.

    k = 2, 3 and 4 are each tried ``cluster_attempts`` times (seeded with
    ``seed + attempt``); the clustering whose centers are most spread out in
    the ab-plane is kept. Clusters that received no pixels are omitted.
    """
    points = pixels.astype(float)

    def separation(result) -> float:
        # Mean squared distance in the ab-plane between every pair of centers.
        # Maximizing this means the clusters are highly distinct, which gives
        # a sense of which k was best.
        centers_ab = result[0][:, 1:]
        diffs = np.array([m1 - m2 for m1, m2 in combinations(centers_ab, 2)])
        return (diffs ** 2).sum(axis=1).mean(axis=0)

    candidates = (
        vq.kmeans2(points, k, minit="++", seed=seed + attempt)
        for k in (2, 3, 4)
        for attempt in range(cluster_attempts)
    )
    means, labels = max(candidates, key=separation)

    # k-means may leave a cluster without members; it has no statistics.
    members = (pixels[labels == i] for i in range(len(means)))
    return [calc_statistics(group) for group in members if len(group) > 0]


class Data(NamedTuple):
    """Ingest result for one pokemon."""

    name: str
    sprite: str            # icon sprite URL
    traits: list[str]
    total: Stats           # statistics over all pixels
    clusters: list[Stats]  # statistics per colour cluster


def get_pixels(img: Image.Image) -> np.ndarray:
    """Return the OKLab pixels of every frame of ``img`` as an ``(n, 3)`` array.

    Fully transparent pixels and pure-black pixels are dropped.
    """
    frames = []
    for fr in range(getattr(img, "n_frames", 1)):
        img.seek(fr)
        rgba = np.asarray(img.convert("RGBA")).reshape(-1, 4)
        # keep pixels with a > 0 and (r, g, b) != (0, 0, 0)
        keep = (rgba[:, 3] > 0) & rgba[:, :3].any(axis=1)
        frames.append(rgba[keep, :3])
    return srgb2oklab(np.concatenate(frames))


async def load_image(session: ClientSession, url: str) -> Image.Image:
    """Download ``url`` and open it as a PIL image.

    An HTTP error status is raised here rather than surfacing later as an
    unreadable image.
    """
    async with session.get(url) as res:
        res.raise_for_status()
        return Image.open(BytesIO(await res.read()))


async def load_all_images(urls: list[str]) -> list[Image.Image]:
    """Download all ``urls`` concurrently over one shared session."""
    async with ClientSession() as session:
        # TODO error handling
        return await asyncio.gather(*(load_image(session, url) for url in urls))


def get_data(name, seed=0) -> Data:
    """Download the images for ``name`` and compute its colour Data."""
    # asyncio.run creates and closes its own loop, which also works inside
    # pool worker processes where no event loop has been set up.
    images = asyncio.run(load_all_images([
        # TODO source images
    ]))

    # TODO error handling
    pixels = np.concatenate([get_pixels(img) for img in images])

    return Data(
        # TODO name normalization
        name=name,
        # TODO sprite URL discovery
        sprite=f"https://img.pokemondb.net/sprites/sword-shield/icon/{name}.png",
        # TODO trait analysis
        traits=[],
        total=calc_statistics(pixels),
        clusters=find_clusters(pixels, seed=seed),
    )


def get_data_for_all(pokemon: list[str], seed=0) -> Generator[Data, None, None]:
    """Yield the Data for each name in ``pokemon``, in completion order.

    Work is spread over a pool of 4 processes in chunks of 100 names.
    """
    # A lambda cannot be pickled for the worker processes; a partial of a
    # module-level function can.
    worker = functools.partial(get_data, seed=seed)
    with multiprocessing.Pool(4) as pool:
        yield from pool.imap_unordered(worker, pokemon, 100)


def name2id(name: str) -> str:
    """Return ``name`` lower-cased with spaces and hyphens removed."""
    return name.replace(" ", "").replace("-", "").lower()


def load_pokedex(path: str) -> dict:
    """Load the pokedex JSON at ``path`` and group its entries by number.

    Returns a mapping from ``num`` to a list of ``(id, entry)`` pairs. Entries
    with cosmetic formes (except unown) are expanded to one pair per forme,
    each a copy of the entry with an added ``"forme"`` key.

    Raises:
        AssertionError: if a number between the lowest and highest ``num``
            has no entry.
    """
    with open(path, encoding="utf-8") as infile:
        pkdx_raw = json.load(infile)

    pkdx = defaultdict(list)
    for key, entry in pkdx_raw.items():
        num = entry["num"]
        cosmetic = entry.get("cosmeticFormes", [])

        # non-cosmetic forms get separate entries automatically
        # but keeping the separate unown forms would be ridiculous
        if key == "unown" or not cosmetic:
            pkdx[num].append((key, entry))
            continue

        # The base forme is listed alongside the cosmetic ones. This appends
        # to the entry's own list, so the copies below include it as well.
        # NOTE(review): built from the lower-case key, so its casing may not
        # match the other forme names; confirm.
        cosmetic.append(f'{key}-{entry["baseForme"].replace(" ", "-")}')
        if key == "alcremie":
            # every cream forme also comes with every sweet
            cosmetic = [
                f"{cf}-{sweet}"
                for cf in cosmetic
                for sweet in [
                    "Strawberry", "Berry", "Love", "Star",
                    "Clover", "Flower", "Ribbon",
                ]
            ]
        pkdx[num].extend(
            (name2id(cf), {**entry, "forme": cf}) for cf in cosmetic
        )

    if not pkdx:
        return pkdx

    # double check there's no skipped entries; raised explicitly so the check
    # still runs under `python -O`
    for i in range(min(pkdx), max(pkdx) + 1):
        if i not in pkdx:
            raise AssertionError(f"pokedex has no entry for number {i}")

    return pkdx


if __name__ == "__main__":
    from sys import argv
    load_pokedex(argv[1] if len(argv) > 1 else "data/pokedex.json")