"""
Compute OKLab color statistics for Pokemon sprites.

Goals:
+ Single module
+ Use OKLab
+ Improved clustering logic
+ Parallel, in the same way as anim-ingest
+ Async requests for downloads
+ Include more info about the pokemon (form, display name, icon sprite source)
+ Include megas/gmax/etc, tagged so the UI can filter them
* Include more images (get more stills from pokemondb + serebii)
* Include shinies
* Fall back automatically (try showdown animated, then showdown gen5, then pdb)
* Make the filtering system more explicit and easier to work around
* Output a record of ingest for auditing
* Automatically retry a partially failed ingest, using the record
"""
import math
import asyncio
import multiprocessing
import json
from collections import defaultdict
from io import BytesIO
from typing import NamedTuple, Generator
from itertools import combinations

import numpy as np
from PIL import Image
from aiohttp import ClientSession
from scipy.cluster import vq


# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v) -> np.ndarray:
    """Decode gamma-encoded sRGB components (nominally in [0, 1]) to linear RGB.

    Works elementwise on any array-like. This uses ``np.where`` rather than
    ``np.vectorize``. ``np.vectorize`` runs a Python-level loop per element
    and raises on size-0 input unless ``otypes`` is given.
    """
    v = np.asarray(v, dtype=float)
    # np.where evaluates both branches; maximum() keeps the power branch away
    # from negative bases (NaN + warning) in places where it isn't selected
    return np.where(
        v <= 0.04045,
        v / 12.92,
        ((np.maximum(v, 0.04045) + 0.055) / 1.055) ** 2.4,
    )


def delinearize_lrgb(v) -> np.ndarray:
    """Encode linear RGB components to gamma-encoded sRGB (inverse of linearize_srgb).

    Works elementwise on any array-like. Out-of-range values are NOT clamped here.
    """
    v = np.asarray(v, dtype=float)
    return np.where(
        v <= 0.0031308,
        v * 12.92,
        (np.maximum(v, 0.0031308) ** (1 / 2.4)) * 1.055 - 0.055,
    )


# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
# linear RGB <-> LMS in one step
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)

# cube-rooted LMS <-> OKLab (L, a, b)
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)


def oklab2hex(pixel: np.ndarray) -> str:
    """Convert one OKLab color (L, a, b) to an sRGB hex string like ``#RRGGBB``.

    Out-of-gamut channels are clamped to [0, 1]. Channels are rounded, not
    truncated, to the nearest 8-bit level, so that e.g. white round-trips
    to ``#FFFFFF`` despite float error.
    """
    # no need for a vectorized version, this is only for providing the mean hex
    lrgb = ((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T
    srgb = np.clip(delinearize_lrgb(lrgb), 0.0, 1.0)
    return "#" + "".join(f"{round(float(x) * 255):02X}" for x in srgb)


def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
    """Convert 8-bit sRGB pixels (last axis R, G, B in 0-255) to OKLab (last axis L, a, b)."""
    lms = linearize_srgb(pixels / 255) @ RGB_TO_LMS.T
    # cbrt (unlike ** (1/3)) is defined for negative values, matching the ** 3 inverse
    return np.cbrt(lms) @ LMS_TO_OKLAB.T


Stats = NamedTuple("Stats", [
    ("size", int),        # number of pixels
    ("variance", float),  # mean squared OKLab distance of the pixels from their mean
    ("stddev", float),    # sqrt(variance)
    ("hex", str),         # the mean color as "#RRGGBB"
    ("Lbar", float),      # mean L
    ("abar", float),      # mean a
    ("bbar", float),      # mean b
    ("Cbar", float),      # mean chroma sqrt(a^2 + b^2)
    ("hbar", float),      # circular mean hue, in degrees
    ("Lhat", float),      # mean L of the pixels after scaling each to unit length
    ("ahat", float),      # mean a of the unit-length pixels
    ("bhat", float),      # mean b of the unit-length pixels
])

Data = NamedTuple("Data", [
    ("total", Stats),           # statistics over all pixels
    ("clusters", list[Stats]),  # statistics per k-means cluster
])

# typename must match the variable name, otherwise instances can't be pickled
# (e.g. when sent between multiprocessing workers)
FormInfo = NamedTuple("FormInfo", [
    ("name", str),
    ("traits", list[str]),  # UI filter tags, see get_traits
    ("types", list[str]),
    ("color", str),
    ("data", Data | None),  # None until the images are processed
])

Pokemon = NamedTuple("Pokemon", [
    ("num", int),
    ("species", str),
    ("sprite", str | None),  # None until found
    ("forms", list[FormInfo]),
])


def calc_statistics(pixels: np.ndarray) -> Stats:
    """Summarize a set of OKLab pixels, given as rows (L, a, b).

    Raises:
        ValueError: if ``pixels`` is empty (every statistic would be NaN).
    """
    if len(pixels) == 0:
        raise ValueError("cannot compute statistics of an empty pixel set")

    # mean pixel of the image, (L-bar, a-bar, b-bar)
    mean = pixels.mean(axis=0)
    # square each component
    squared = pixels ** 2
    # Euclidean norm squared by summing squared components
    sqnorms = squared.sum(axis=1)
    # mean pixel of normalized image, (L-hat, a-hat, b-hat)
    tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
    # variance = mean(||p||^2) - ||mean(p)||^2
    # float cancellation can leave a (near-)zero variance slightly negative,
    # which would make math.sqrt raise, so clamp at 0
    variance = max(sqnorms.mean(axis=0) - (mean ** 2).sum(), 0.0)
    # chroma^2 = a^2 + b^2
    chroma = np.sqrt(squared[:, 1:].sum(axis=1))
    # hue = atan2(b, a), but we need a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
    # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
    # Achromatic pixels (chroma == 0) have no hue. They contribute a zero
    # vector instead of 0/0 = NaN. A zero vector does not change the
    # direction of the mean vector, so atan2 ignores them.
    unit_ab = np.divide(
        pixels[:, 1:],
        chroma[:, np.newaxis],
        out=np.zeros(pixels[:, 1:].shape, dtype=float),
        where=chroma[:, np.newaxis] > 0,
    )
    a_mean, b_mean = unit_ab.mean(axis=0)
    hue = math.atan2(b_mean, a_mean)

    return Stats(
        size=len(pixels),
        variance=variance,
        stddev=math.sqrt(variance),
        hex=oklab2hex(mean),
        Lbar=mean[0],
        abar=mean[1],
        bbar=mean[2],
        Cbar=chroma.mean(axis=0),
        hbar=hue * 180 / math.pi,
        Lhat=tilt[0],
        ahat=tilt[1],
        bhat=tilt[2],
    )


def find_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0,
                  k_values=(2, 3, 4)) -> list[Stats]:
    """Cluster OKLab pixels with k-means and return statistics for each cluster.

    Runs ``scipy.cluster.vq.kmeans2`` with k-means++ init for each k in
    ``k_values``. Each k is tried ``cluster_attempts`` times, with seeds
    ``seed``, ``seed + 1``, .... The run kept is the one whose centers are
    furthest apart in the a-b plane, measured as the mean squared pairwise
    distance; ties go to the earliest run. Clusters left empty by k-means are
    omitted from the result.
    """
    data = pixels.astype(float)

    def separation(result) -> float:
        # Maximizing this means the clusters are highly distinct, which
        # gives a sense of which k was best.
        centers_ab = result[0][:, 1:]
        return np.mean([
            np.sum((c1 - c2) ** 2) for c1, c2 in combinations(centers_ab, 2)
        ])

    means, labels = max(
        (
            vq.kmeans2(data, k, minit="++", seed=seed + i)
            for k in k_values
            for i in range(cluster_attempts)
        ),
        key=separation,
    )
    return [
        calc_statistics(pixels[labels == i])
        for i in range(len(means))
        if np.any(labels == i)
    ]


def get_pixels(img: Image.Image) -> np.ndarray:
    """Return the OKLab colors of img's pixels, across all frames, as an (N, 3) array.

    Fully transparent pixels (alpha 0) and pure black (0, 0, 0) pixels are
    dropped. Black is presumably dropped because it is outline colour; TODO
    confirm. Pixels are ordered frame by frame, then row-major. This seeks
    ``img``, leaving it on its last frame.
    """
    kept = []
    for frame in range(getattr(img, "n_frames", 1)):
        img.seek(frame)
        rgba = np.asarray(img.convert("RGBA")).reshape(-1, 4)
        visible = (rgba[:, 3] > 0) & rgba[:, :3].any(axis=1)
        kept.append(rgba[visible, :3])
    return srgb2oklab(np.concatenate(kept))


async def load_image(session: ClientSession, url: str) -> Image.Image:
    """Download ``url`` with ``session`` and open it as a PIL image.

    Raises ``aiohttp.ClientResponseError`` for an HTTP error status. Without
    this check, the error page body would be handed to PIL instead.
    """
    async with session.get(url) as res:
        res.raise_for_status()
        return Image.open(BytesIO(await res.read()))


async def load_all_images(urls: list[str]) -> list[Image.Image]:
    """Download all ``urls`` concurrently over one session, preserving order."""
    async with ClientSession() as session:
        # TODO error handling
        return await asyncio.gather(*(load_image(session, url) for url in urls))


def get_data(urls: list[str], seed=0) -> Data:
    """Download the images at ``urls``, pool their pixels, and compute statistics.

    ``seed`` is passed through to ``find_clusters``.
    """
    # asyncio.run creates and closes its own loop. get_event_loop() is
    # deprecated when no loop is running.
    images = asyncio.run(load_all_images(urls))
    # TODO error handling
    pixels = np.concatenate([get_pixels(img) for img in images])
    return Data(
        total=calc_statistics(pixels),
        clusters=find_clusters(pixels, seed=seed),
    )


def get_traits(species: str, form: dict) -> list[str]:
    """Return the UI filter tags for one form of ``species``.

    ``form`` is one of the per-form dicts built by ``load_pokedex``; only its
    ``"formeKind"`` key is read. The result is sorted and duplicate-free.
    Plain ``set`` order is not used because it depends on the string hash
    seed, which would make output differ between runs.
    """
    kind = form["formeKind"]
    traits = []

    if kind in ("mega", "mega-x", "mega-y", "primal"):
        traits.extend(("mega", "nostart"))
    if kind in ("gmax", "eternamax", "rapid-strike-gmax"):
        traits.extend(("gmax", "nostart"))
    if kind in ("alola", "galar", "hisui", "paldea"):
        traits.extend(("regional", kind))

    # special cases
    if species == "Tauros" and "paldea" in kind:
        # paldean tauros has dumb names, so the exact match above misses them.
        # "paldea" is a strict superset of the old "-paldea" substring test.
        # NOTE(review): the forme kinds look like "paldea-combat", which
        # "-paldea" never matched. Confirm against the dex.
        traits.extend(("regional", "paldea"))
    if species == "Minior" and kind != "meteor":
        # minior can only start the battle in meteor form
        traits.append("nostart")
    if species == "Darmanitan" and "zen" in kind:
        # darmanitan cannot start in zen form
        traits.append("nostart")
        if "galar" in kind:
            # also there's a galar-zen form to handle
            traits.extend(("regional", "galar"))
    if species == "Palafin" and kind == "hero":
        # palafin can only start in zero form
        traits.append("nostart")
    if species == "Gimmighoul" and kind == "roaming":
        # gimmighoul roaming is only in PGO
        traits.append("nostart")

    return sorted(set(traits))


# https://bulbapedia.bulbagarden.net/wiki/List_of_Pok%C3%A9mon_with_gender_differences
# there are some pokemon with notable gender diffs that the dex doesn't cover
# judgement calls made arbitrarily
GENDER_DIFFS = (
    "hippopotas",
    "hippowdon",
    "unfezant",
    "frillish",
    "jellicent",
    "pyroar",
    # meowstic, indeedee, basculegion, oinkologne are already handled in the dex
)


def load_pokedex(path: str) -> Generator[Pokemon, None, None]:
    """Yield one Pokemon per dex number, from 1 to the highest number in the file.

    ``path`` is a JSON object mapping keys to entries, presumably Pokemon
    Showdown's pokedex export (verify). The entry keys read are "num",
    "name", "forme", "baseForme", "baseSpecies", "cosmeticFormes", "types"
    and "color". Entries are grouped by "num" in file order. Cosmetic formes
    (except Unown's) and the species in GENDER_DIFFS are expanded into one
    form each.

    Raises:
        ValueError: if a dex number in that range has no entries. This is
            raised lazily, while iterating.
    """
    with open(path, encoding="utf-8") as infile:
        pkdx_raw = json.load(infile)

    # dex number -> per-form dicts (the raw entry plus "forme"/"formeKind")
    pkdx = defaultdict(list)
    for key, entry in pkdx_raw.items():
        num = entry["num"]
        # copy, so the raw entry's list is not mutated by the append below
        cosmetic = list(entry.get("cosmeticFormes", []))

        # non-cosmetic forms get separate entries automatically
        # but keeping the separate unown forms would be ridiculous
        if key != "unown" and cosmetic:
            # the base forme is not in the cosmetic list, so add it
            cosmetic.append(f'{entry["name"]}-{entry["baseForme"]}')
            if key == "alcremie":
                # oh god this thing: every cream comes with every sweet
                cosmetic = [
                    f"{cf}-{sweet}"
                    for cf in cosmetic
                    for sweet in (
                        "Strawberry", "Berry", "Love", "Star",
                        "Clover", "Flower", "Ribbon",
                    )
                ]
            pkdx[num].extend(
                {**entry, "forme": cf.replace(" ", "-"), "formeKind": "cosmetic"}
                for cf in cosmetic
            )
        elif key in GENDER_DIFFS:
            pkdx[num].extend(
                {**entry, "forme": f'{entry["name"]}-{gender}', "formeKind": "cosmetic"}
                for gender in ("M", "F")
            )
        else:
            pkdx[num].append({
                **entry,
                "forme": entry["name"],
                "formeKind": entry.get("forme", "base").lower(),
            })

    for num in range(1, max(pkdx, default=0) + 1):
        forms = pkdx.get(num)
        # double check there's no skipped entries. This is an explicit raise
        # rather than assert so it still runs under `python -O`.
        if not forms:
            raise ValueError(f"{path}: no pokedex entries for #{num}")

        species = forms[0].get("baseSpecies", forms[0]["name"])
        yield Pokemon(
            num=num,
            species=species,
            sprite=None,  # found later
            forms=[
                FormInfo(
                    name=f.get("forme", f["name"]),
                    traits=get_traits(species, f),
                    types=f["types"],
                    color=f["color"],
                    data=None,  # found later
                )
                for f in forms
            ],
        )


if __name__ == "__main__":
    from sys import argv

    dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
    # out_file / log_file are not used yet
    out_file = argv[2] if len(argv) > 2 else "data/database-latest.js"
    log_file = argv[3] if len(argv) > 3 else "ingest.log"

    pkdx = list(load_pokedex(dex_file))
    print(json.dumps(pkdx[5], indent=2))
    print(json.dumps(pkdx[285], indent=2))
    print(json.dumps(pkdx[773], indent=2))

    # NOTE(review): Pool pickles the callable, and lambdas can't be pickled.
    # Use a module-level function, e.g. functools.partial(get_data, seed=seed).
    # with multiprocessing.Pool(4) as pool:
    #     yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), pokemon, 100)