@@ -0,0 +1,241 @@
+import math
+import asyncio
+import multiprocessing
+import json
+from collections import defaultdict
+from functools import partial
+from io import BytesIO
+from typing import NamedTuple, Generator
+from itertools import combinations
+
+import numpy as np
+from PIL import Image
+from aiohttp import ClientSession
+from scipy.cluster import vq
+
+"""
+Goals:
+ + Single module
+ + Use OKLab
+ + Improved clustering logic
+ + Parallel, in the same way as anim-ingest
+ + Async requests for downloads
+ * Include more info about the pokemon (form, display name, icon sprite source)
+ * Include more images (get more stills from pokemondb + serebii)
+ * Include shinies + megas, tagged so the UI can filter them
+ * Fall back automatically (try showdown animated, then showdown gen5, then pdb)
+ * Filtering system more explicit and easier to work around
+ * Output a record of ingest for auditing
+ * Automatic retry of a partially failed ingest, using record
+"""
+
+# https://en.wikipedia.org/wiki/SRGB#Transformation
+linearize_srgb = np.vectorize(
+    lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
+)
+delinearize_lrgb = np.vectorize(
+    lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
+)
+# https://mina86.com/2019/srgb-xyz-matrix/
+RGB_TO_XYZ = np.array([
+    [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
+    [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
+    [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
+])
+XYZ_TO_RGB = np.array([
+    [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
+    [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
+    [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
+])
+
+# https://bottosson.github.io/posts/oklab/
+XYZ_TO_LMS = np.array([
+    [0.8189330101, 0.3618667424, -0.1288597137],
+    [0.0329845436, 0.9293118715, 0.0361456387],
+    [0.0482003018, 0.2643662691, 0.6338517070],
+])
+RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
+LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
+LMS_TO_OKLAB = np.array([
+    [0.2104542553, 0.7936177850, -0.0040720468],
+    [1.9779984951, -2.4285922050, 0.4505937099],
+    [0.0259040371, 0.7827717662, -0.8086757660],
+])
+OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
+
+
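+# The two converters below chain these matrices into the full OKLab pipeline:
+# 8-bit sRGB -> [0, 1] -> linear RGB -> LMS (RGB_TO_LMS folds in the XYZ step)
+# -> cube root -> OKLab, and the exact inverse on the way back out to a hex
+# string. See the Björn Ottosson post linked above for the derivation.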
+def oklab2hex(pixel: np.ndarray) -> str:
+    # no need for a vectorized version, this is only for providing the mean hex
+    rgb = delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T)
+    return "#" + "".join(f"{int(x * 255):02X}" for x in rgb)
+
+
+def srgb2oklab(pixels: np.ndarray) -> np.ndarray:
+    # pixels is an (N, 3) array of 8-bit sRGB values; the result is (N, 3) OKLab
+    return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
+
+
+Stats = NamedTuple("Stats", [
+    ("size", int),
+    ("variance", float),
+    ("stddev", float),
+    ("hex", str),
+    ("Lbar", float),
+    ("abar", float),
+    ("bbar", float),
+    ("Cbar", float),
+    ("hbar", float),
+    ("Lhat", float),
+    ("ahat", float),
+    ("bhat", float),
+])
+
+
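+# Summary statistics for one set of OKLab pixels: the channel means (Lbar, abar,
+# bbar) with their hex rendering, mean chroma and hue (Cbar, hbar), the means of
+# the unit-normalized pixels (Lhat, ahat, bhat), and the total variance computed
+# with the usual shortcut E[||p - mu||^2] = E[||p||^2] - ||mu||^2.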
+def calc_statistics(pixels: np.ndarray) -> Stats:
+    # mean pixel of the image, (L-bar, a-bar, b-bar)
+    mean = pixels.mean(axis=0)
+    # square each component
+    squared = pixels ** 2
+    # Euclidean norm squared by summing squared components
+    sqnorms = squared.sum(axis=1)
+    # mean pixel of normalized image, (L-hat, a-hat, b-hat)
+    tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
+    # variance = mean(||p||^2) - ||mean(p)||^2
+    variance = sqnorms.mean(axis=0) - sum(mean ** 2)
+    # chroma^2 = a^2 + b^2, C-bar = mean(sqrt(a^2 + b^2))
+    chroma = np.sqrt(squared[:, 1:].sum(axis=1)).mean(axis=0)
+    # hue = atan2(b, a), h-bar = mean(atan2(b, a)) in degrees
+    hue = np.arctan2(pixels[:, 2], pixels[:, 1]).mean(axis=0) * 180 / math.pi
+    return Stats(
+        size=len(pixels),
+        variance=variance,
+        stddev=math.sqrt(variance),
+        hex=oklab2hex(mean),
+        Lbar=mean[0],
+        abar=mean[1],
+        bbar=mean[2],
+        Cbar=chroma,
+        hbar=hue,
+        Lhat=tilt[0],
+        ahat=tilt[1],
+        bhat=tilt[2],
+    )
+
+
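+# Pick the best of several k-means runs: try k = 2..4 a few times each and keep
+# the run whose centers are most spread out in the ab-plane (lightness ignored),
+# then summarize each resulting cluster with the same Stats as the whole image.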
+def find_clusters(pixels: np.ndarray, cluster_attempts=5, seed=0) -> list[Stats]:
+    means, labels = max(
+        (
+            # Try k = 2, 3, and 4, and try a few times for each
+            vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
+            for k in (2, 3, 4)
+            for i in range(cluster_attempts)
+        ),
+        key=lambda c:
+            # Evaluate clustering by seeing the average distance in the ab-plane
+            # between the centers. Maximizing this means the clusters are highly
+            # distinct, which gives a sense of which k was best.
+            (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
+            .sum(axis=1)
+            .mean(axis=0)
+    )
+    return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
+
+
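+# One record per ingested Pokémon form: its name, icon sprite URL, trait tags
+# the UI can filter on (shiny, mega, and so on, per the goals above), overall
+# Stats for all pixels, and Stats per color cluster.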
+Data = NamedTuple("Data", [
+    ("name", str),
+    ("sprite", str),
+    ("traits", list[str]),
+    ("total", Stats),
+    ("clusters", list[Stats]),
+])
+
+
+def get_pixels(img: Image.Image) -> np.ndarray:
+    rgb = []
+    # collect the opaque, non-black pixels from every frame of a (possibly
+    # animated) image, then convert the whole batch to OKLab at once
+    for fr in range(getattr(img, "n_frames", 1)):
+        img.seek(fr)
+        rgb += [
+            [r, g, b]
+            for r, g, b, a in img.convert("RGBA").getdata()
+            if a > 0 and (r, g, b) != (0, 0, 0)
+        ]
+    return srgb2oklab(np.array(rgb))
+
+
+async def load_image(session: ClientSession, url: str) -> Image.Image:
+    async with session.get(url) as res:
+        return Image.open(BytesIO(await res.read()))
+
+
+async def load_all_images(urls: list[str]) -> list[Image.Image]:
+    async with ClientSession() as session:
+        # TODO error handling
+        return await asyncio.gather(*(load_image(session, url) for url in urls))
+
+
+def get_data(name: str, seed=0) -> Data:
+    images = asyncio.run(load_all_images([
+        # TODO source images
+    ]))
+
+    # TODO error handling
+
+    pixels = np.concatenate([get_pixels(img) for img in images])
+
+    return Data(
+        # TODO name normalization
+        name=name,
+        # TODO sprite URL discovery
+        sprite=f"https://img.pokemondb.net/sprites/sword-shield/icon/{name}.png",
+        # TODO trait analysis
+        traits=[],
+        total=calc_statistics(pixels),
+        clusters=find_clusters(pixels, seed=seed),
+    )
+
+
+def get_data_for_all(pokemon: list[str], seed=0) -> Generator[Data, None, None]:
+    # the mapped function has to be picklable for the worker processes, so bind
+    # the seed with partial rather than a lambda and pass the names directly
+    with multiprocessing.Pool(4) as pool:
+        yield from pool.imap_unordered(partial(get_data, seed=seed), pokemon, 100)
+
+
+def name2id(name: str) -> str:
+    return name.replace(" ", "").replace("-", "").lower()
+
+
+def load_pokedex(path: str) -> dict:
+    with open(path) as infile:
+        pkdx_raw = json.load(infile)
+
+    pkdx = defaultdict(list)
+
+    for key, entry in pkdx_raw.items():
+        num = entry["num"]
+        # non-cosmetic forms get separate entries automatically,
+        # but keeping the separate unown forms would be ridiculous
+        if key != "unown" and len(cosmetic := entry.get("cosmeticFormes", [])) > 0:
+            cosmetic.append(f'{key}-{entry["baseForme"].replace(" ", "-")}')
+            if key == "alcremie":
+                # oh god this thing
+                cosmetic = [
+                    f"{cf}-{sweet}"
+                    for cf in cosmetic
+                    for sweet in [
+                        "Strawberry", "Berry", "Love", "Star",
+                        "Clover", "Flower", "Ribbon",
+                    ]
+                ]
+            pkdx[num].extend((name2id(cf), {
+                **entry,
+                "forme": cf,
+            }) for cf in cosmetic)
+        else:
+            pkdx[num].append((key, entry))
+
+    for i in range(min(pkdx.keys()), max(pkdx.keys()) + 1):
+        # double check there are no skipped entries
+        assert len(pkdx[i]) > 0
+
+    return pkdx
+
+
+if __name__ == "__main__":
+    from sys import argv
+    load_pokedex(argv[1] if len(argv) > 1 else "data/pokedex.json")
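+    # One possible end-to-end driver, left as a comment sketch since the
+    # source-image URLs above are still TODO (it assumes the pokedex ids are
+    # the names get_data expects):
+    #
+    #   pkdx = load_pokedex(argv[1] if len(argv) > 1 else "data/pokedex.json")
+    #   names = [pid for forms in pkdx.values() for pid, _ in forms]
+    #   for data in get_data_for_all(names):
+    #       print(json.dumps(data._asdict()))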