123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- import math
- import asyncio
- import multiprocessing
- import json
- from collections import defaultdict
- from io import BytesIO
- from typing import NamedTuple, Generator
- from itertools import combinations
- import numpy as np
- from PIL import Image
- from aiohttp import ClientSession
- from scipy.cluster import vq
- """
- Goals:
- + Single module
- + Use OKLab
- + Improved clustering logic
- + Parallel, in the same way as anim-ingest
- + Async requests for downloads
- + Include more info about the pokemon (form, display name, icon sprite source)
- + Include megas/gmax/etc, tagged so the UI can filter them
- * Include more images (get more stills from pokemondb + serebii)
- * Include shinies
- * Fallback automatically (try showdown animated, then showdown gen5, then pdb)
- * Filtering system more explicit and easier to work around
- * Output a record of ingest for auditing
- * Automatic retry of a partially failed ingest, using record
- """
# https://en.wikipedia.org/wiki/SRGB#Transformation
def linearize_srgb(v):
    """Convert gamma-encoded sRGB component(s) in [0, 1] to linear light.

    Accepts a scalar or an array of any shape (including empty arrays) and
    returns a float array of the same shape.
    """
    v = np.asarray(v, dtype=float)
    # np.where evaluates both branches for every element, so clamp the base of
    # the power branch to keep negative inputs from producing NaN warnings.
    # The clamp never alters a value that is actually selected, because the
    # power branch is only chosen for v > 0.04045.
    curved = ((np.maximum(v, 0.0) + 0.055) / 1.055) ** 2.4
    return np.where(v <= 0.04045, v / 12.92, curved)


def delinearize_lrgb(v):
    """Convert linear-light RGB component(s) to gamma-encoded sRGB.

    Accepts a scalar or an array of any shape (including empty arrays) and
    returns a float array of the same shape. Values at or below the linear
    threshold (including negative ones) are scaled linearly.
    """
    v = np.asarray(v, dtype=float)
    # Same clamping trick as in linearize_srgb: the power branch is only
    # selected for v > 0.0031308, so the clamp is invisible in the result.
    curved = np.maximum(v, 0.0) ** (1 / 2.4) * 1.055 - 0.055
    return np.where(v <= 0.0031308, v * 12.92, curved)
# Exact rational sRGB <-> CIE XYZ (D65) matrices.
# https://mina86.com/2019/srgb-xyz-matrix/
RGB_TO_XYZ = np.array([
    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
])
# Kept as an ndarray like every other matrix here so it supports `@` and `.T`.
XYZ_TO_RGB = np.array([
    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
])

# XYZ -> LMS cone response, from the OKLab definition.
# https://bottosson.github.io/posts/oklab/
XYZ_TO_LMS = np.array([
    [0.8189330101, 0.3618667424, -0.1288597137],
    [0.0329845436, 0.9293118715, 0.0361456387],
    [0.0482003018, 0.2643662691, 0.6338517070],
])
# Linear RGB -> LMS in one step, and its inverse.
RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)

# Applied to the cube root of LMS to obtain (L, a, b); the inverse maps back.
LMS_TO_OKLAB = np.array([
    [0.2104542553, 0.7936177850, -0.0040720468],
    [1.9779984951, -2.4285922050, 0.4505937099],
    [0.0259040371, 0.7827717662, -0.8086757660],
])
OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
def oklab2hex(pixel: np.ndarray) -> str:
    """Format a single OKLab pixel (L, a, b) as an ``#RRGGBB`` sRGB hex string.

    Channels are rounded to the nearest 8-bit value and clamped to 0..255, so
    out-of-gamut colours still yield a well-formed six-digit code. A NaN
    component raises ValueError.
    """
    # no need for a vectorized version, this is only for providing the mean hex
    lms = (pixel @ OKLAB_TO_LMS.T) ** 3
    srgb = delinearize_lrgb(lms @ LMS_TO_RGB.T)
    # Round rather than truncate: floating-point error can leave a full-scale
    # channel at 254.999..., which int() would turn into 0xFE.
    channels = (min(max(round(float(x) * 255), 0), 255) for x in srgb)
    return "#" + "".join(f"{c:02X}" for c in channels)
def srgb2oklab(pixels: np.array) -> np.array:
    """Convert 8-bit sRGB pixel rows (0..255 per channel) to OKLab rows."""
    # Scale to 0..1 and undo the sRGB transfer curve.
    linear_rgb = linearize_srgb(pixels / 255)
    # Linear RGB -> LMS cone response.
    lms = linear_rgb @ RGB_TO_LMS.T
    # Cube-root non-linearity, then the final OKLab mix.
    return lms ** (1 / 3) @ LMS_TO_OKLAB.T
class Stats(NamedTuple):
    """Summary statistics for a set of OKLab pixels."""

    # number of pixels summarized
    size: int
    # mean squared norm minus squared norm of the mean
    variance: float
    # square root of the variance
    stddev: float
    # mean colour as an sRGB hex string
    hex: str
    # components of the mean pixel
    Lbar: float
    abar: float
    bbar: float
    # mean chroma
    Cbar: float
    # circular mean hue, in degrees
    hbar: float
    # components of the mean of the unit-normalized pixels
    Lhat: float
    ahat: float
    bhat: float
class Data(NamedTuple):
    """Colour statistics for one form: overall plus per-cluster."""

    # statistics over all pixels
    total: Stats
    # statistics of each colour cluster
    clusters: list[Stats]
# The typename now matches the bound name (it was "FormData"), so repr() and
# pickling refer to a class that actually exists under that name.
class FormInfo(NamedTuple):
    """Description of a single form of a Pokemon."""

    # form name
    name: str
    # tags such as "mega", "regional" or "nostart"
    traits: list[str]
    # type names
    types: list[str]
    # dex colour category
    color: str
    # image statistics; None until they have been computed
    data: Data | None
class Pokemon(NamedTuple):
    """One dex entry: a species together with all of its forms."""

    # dex number
    num: int
    # species name
    species: str
    # icon sprite; None until it has been located
    sprite: str | None
    # every form recorded for this dex number
    forms: list[FormInfo]
def calc_statistics(pixels: np.ndarray) -> Stats:
    """Summarize a set of OKLab pixels.

    ``pixels`` is a 2-D array with one (L, a, b) row per pixel. Pixels with a
    zero norm contribute a zero vector to the normalized mean, and pixels with
    zero chroma contribute nothing to the hue direction, instead of poisoning
    the results with NaN.
    """
    # mean pixel of the image, (L-bar, a-bar, b-bar)
    mean = pixels.mean(axis=0)
    # square each component
    squared = pixels ** 2
    # Euclidean norm squared by summing squared components
    sqnorms = squared.sum(axis=1)
    # mean pixel of normalized image, (L-hat, a-hat, b-hat)
    norms = np.sqrt(sqnorms)[:, np.newaxis]
    tilt = np.divide(
        pixels,
        norms,
        out=np.zeros(pixels.shape, dtype=float),
        where=norms > 0,
    ).mean(axis=0)
    # variance = mean(||p||^2) - ||mean(p)||^2
    # Rounding can push this a hair below zero for (near-)identical pixels,
    # which would make math.sqrt raise, so clamp at zero.
    variance = max(sqnorms.mean(axis=0) - sum(mean ** 2), 0.0)
    # chroma^2 = a^2 + b^2
    chroma = np.sqrt(squared[:, 1:].sum(axis=1))
    # hue = atan2(b, a), but we need a circular mean
    # https://en.wikipedia.org/wiki/Circular_mean#Definition
    # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
    # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
    chroma_col = chroma[:, np.newaxis]
    # columns are (sin, cos); achromatic pixels have no hue and stay (0, 0)
    hue_vectors = np.divide(
        pixels[:, [2, 1]],
        chroma_col,
        out=np.zeros((len(pixels), 2), dtype=float),
        where=chroma_col > 0,
    )
    hue = math.atan2(*hue_vectors.mean(axis=0))
    return Stats(
        size=len(pixels),
        variance=variance,
        stddev=math.sqrt(variance),
        hex=oklab2hex(mean),
        Lbar=mean[0],
        abar=mean[1],
        bbar=mean[2],
        Cbar=chroma.mean(axis=0),
        hbar=hue * 180 / math.pi,
        Lhat=tilt[0],
        ahat=tilt[1],
        bhat=tilt[2],
    )
def find_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0) -> list[Stats]:
    """Split the pixels into colour clusters and summarize each one.

    Runs k-means for k = 2, 3 and 4, ``cluster_attempts`` times each with
    seeds ``seed``, ``seed + 1``, ..., and keeps the run whose centers are
    furthest apart in the ab-plane. Clusters that end up with no members are
    omitted from the result.
    """

    def center_spread(result):
        # Evaluate clustering by seeing the average squared distance in the
        # ab-plane between the centers. Maximizing this means the clusters
        # are highly distinct, which gives a sense of which k was best.
        centers_ab = result[0][:, 1:]
        diffs = np.array([m1 - m2 for m1, m2 in combinations(centers_ab, 2)])
        return (diffs ** 2).sum(axis=1).mean(axis=0)

    # Convert once instead of once per k-means run.
    points = pixels.astype(float)
    means, labels = max(
        (
            # Try k = 2, 3, and 4, and try a few times for each
            vq.kmeans2(points, k, minit="++", seed=seed + i)
            for k in (2, 3, 4)
            for i in range(cluster_attempts)
        ),
        key=center_spread,
    )
    # k-means can leave a center with no assigned points; statistics of an
    # empty set are undefined, so skip those clusters.
    return [
        calc_statistics(members)
        for i in range(len(means))
        if len(members := pixels[labels == i]) > 0
    ]
def get_pixels(img: Image.Image) -> np.ndarray:
    """Collect the visible pixels of every frame of ``img`` as OKLab rows.

    Fully transparent pixels and pure black (0, 0, 0) pixels are dropped.
    Returns an array with one (L, a, b) row per kept pixel; if nothing is
    kept, an empty array of shape (0, 3) is returned.
    """
    frames = []
    for fr in range(getattr(img, "n_frames", 1)):
        img.seek(fr)
        # One row per pixel: (r, g, b, a)
        rgba = np.asarray(img.convert("RGBA")).reshape(-1, 4)
        rgb = rgba[:, :3]
        # keep pixels that are not fully transparent and not pure black
        visible = (rgba[:, 3] > 0) & rgb.any(axis=1)
        frames.append(rgb[visible])
    kept = np.concatenate(frames) if frames else np.empty((0, 3))
    if len(kept) == 0:
        # Nothing to convert; keep the two-dimensional shape for callers.
        return np.empty((0, 3))
    return srgb2oklab(kept.astype(float))
async def load_image(session: ClientSession, url: str) -> Image.Image:
    """Download ``url`` with ``session`` and open the response body as an image.

    Raises the HTTP client's response error for 4xx/5xx statuses, so an error
    page is never handed to the image decoder.
    """
    async with session.get(url) as res:
        res.raise_for_status()
        return Image.open(BytesIO(await res.read()))
async def load_all_images(urls: list[str]) -> list[Image.Image]:
    """Download every URL concurrently over one shared client session.

    The images are returned in the same order as ``urls``.
    """
    async with ClientSession() as session:
        downloads = [load_image(session, url) for url in urls]
        # TODO error handling
        return await asyncio.gather(*downloads)
def get_data(urls: list[str], seed=0) -> Data:
    """Download the images at ``urls`` and compute their colour statistics.

    The pixels of all images are pooled; the result holds statistics over the
    whole pool plus statistics for each colour cluster. ``seed`` is passed on
    to the clustering step.
    """
    # asyncio.run creates and closes its own event loop, so this also works in
    # a thread or worker process that has no current loop.
    images = asyncio.run(load_all_images(urls))
    # TODO error handling
    pixels = np.concatenate([get_pixels(img) for img in images])
    return Data(
        total=calc_statistics(pixels),
        clusters=find_clusters(pixels, seed=seed),
    )
- def get_traits(species: str, form: dict) -> list[str]:
- kind = form["formeKind"]
- traits = []
- if kind in ("mega", "mega-x", "mega-y", "primal"):
- traits.extend(("mega", "nostart"))
- if kind in ("gmax", "eternamax", "rapid-strike-gmax"):
- traits.extend(("gmax", "nostart"))
- if kind in ("alola", "galar", "hisui", "galar", "paldea"):
- traits.extend(("regional", kind))
- # special cases
- if species == "Tauros" and "-paldea" in kind:
- # paldean tauros has dumb names
- traits.extend(("regional", "paldea"))
- if species == "Minior" and kind != "meteor":
- # minior can only start the battle in meteor form
- traits.append("nostart")
- if species == "Darmanitan" and "zen" in kind:
- # darmanitan cannot start in zen form
- traits.append("nostart")
- if "galar" in kind:
- # also there's a galar-zen form to handle
- traits.extend(("regional", "galar"))
- if species == "Palafin" and kind == "hero":
- # palafin can only start in zero form
- traits.append("nostart")
- if species == "Gimmighoul" and kind == "roaming":
- # gimmighoul roaming is only in PGO
- traits.append("nostart")
- return list(set(traits))
# Dex keys of species that get separate male and female forms even though the
# dex itself lists only one. Which species qualify is an arbitrary judgement
# call based on how notable the difference is.
# https://bulbapedia.bulbagarden.net/wiki/List_of_Pok%C3%A9mon_with_gender_differences
# (meowstic, indeedee, basculegion, oinkologne are already handled in the dex)
GENDER_DIFFS = (
    "hippopotas",
    "hippowdon",
    "unfezant",
    "frillish",
    "jellicent",
    "pyroar",
)
def load_pokedex(path: str) -> Generator[Pokemon, None, None]:
    """Read the pokedex JSON at ``path`` and yield one Pokemon per dex number.

    Entries are grouped by their ``"num"``; cosmetic formes and the species in
    GENDER_DIFFS are expanded into separate forms. Dex numbers are yielded in
    order from 1 up to the largest number in the file.

    Raises ValueError if a dex number in that range has no entry.
    """
    # The dex contains non-ASCII names, so do not rely on the locale encoding.
    with open(path, encoding="utf-8") as infile:
        pkdx_raw = json.load(infile)

    pkdx = defaultdict(list)
    for key, entry in pkdx_raw.items():
        num = entry["num"]
        cosmetic = entry.get("cosmeticFormes", [])
        # non-cosmetic forms get separate entries automatically
        # but keeping the separate unown forms would be ridiculous
        if key != "unown" and cosmetic:
            # Build a new list so the loaded entry is left unmodified.
            cosmetic = [*cosmetic, f'{entry["name"]}-{entry["baseForme"]}']
            if key == "alcremie":
                # oh god this thing
                cosmetic = [
                    f"{cf}-{sweet}"
                    for cf in cosmetic
                    for sweet in [
                        "Strawberry", "Berry", "Love", "Star",
                        "Clover", "Flower", "Ribbon",
                    ]
                ]
            pkdx[num].extend({
                **entry,
                "forme": cf.replace(" ", "-"),
                "formeKind": "cosmetic",
            } for cf in cosmetic)
        elif key in GENDER_DIFFS:
            pkdx[num].extend({
                **entry,
                "forme": f'{entry["name"]}-{gender}',
                "formeKind": "cosmetic",
            } for gender in ("M", "F"))
        else:
            pkdx[num].append({
                **entry,
                "forme": entry["name"],
                "formeKind": entry.get("forme", "base").lower(),
            })

    for i in range(1, max(pkdx, default=0) + 1):
        # .get() so that a missing number is not silently inserted
        forms = pkdx.get(i)
        # double check there's no skipped entries
        if not forms:
            raise ValueError(f"pokedex has no entry for dex number {i}")
        # yield forms
        species = forms[0].get("baseSpecies", forms[0]["name"])
        yield Pokemon(
            num=i,
            species=species,
            sprite=None,  # found later
            forms=[
                FormInfo(
                    name=f.get("forme", f["name"]),
                    traits=get_traits(species, f),
                    types=f["types"],
                    color=f["color"],
                    data=None,  # found later
                ) for f in forms
            ],
        )
if __name__ == "__main__":
    from sys import argv

    # Optional positional arguments: dex file, output file, log file.
    dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
    out_file = argv[2] if len(argv) > 2 else "data/database-latest.js"
    # argv[3] only exists when there are more than three entries in argv
    log_file = argv[3] if len(argv) > 3 else "ingest.log"

    # load_pokedex requires the path of the dex file
    pkdx = list(load_pokedex(dex_file))

    # spot-check a few entries
    print(json.dumps(pkdx[5], indent=2))
    print(json.dumps(pkdx[285], indent=2))
    print(json.dumps(pkdx[773], indent=2))

    # with multiprocessing.Pool(4) as pool:
    #     yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), pokemon, 100)
|