ingest2.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import math
  2. import asyncio
  3. import multiprocessing
  4. import json
  5. from collections import defaultdict
  6. from io import BytesIO
  7. from typing import NamedTuple, Generator
  8. from itertools import combinations
  9. import numpy as np
  10. from PIL import Image
  11. from aiohttp import ClientSession
  12. from scipy.cluster import vq
  13. """
  14. Goals:
  15. + Single module
  16. + Use OKLab
  17. + Improved clustering logic
  18. + Parallel, in the same way as anim-ingest
  19. + Async requests for downloads
  20. + Include more info about the pokemon (form, display name, icon sprite source)
  21. + Include megas/gmax/etc, tagged so the UI can filter them
  22. * Include more images (get more stills from pokemondb + serebii)
  23. * Include shinies
  24. * Fallback automatically (try showdown animated, then showdown gen5, then pdb)
  25. * Filtering system more explicit and easier to work around
  26. * Output a record of ingest for auditing
  27. * Automatic retry of a partially failed ingest, using record
  28. """
  29. # https://en.wikipedia.org/wiki/SRGB#Transformation
  30. linearize_srgb = np.vectorize(
  31. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  32. )
  33. delinearize_lrgb = np.vectorize(
  34. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  35. )
  36. # https://mina86.com/2019/srgb-xyz-matrix/
  37. RGB_TO_XYZ = np.array([
  38. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  39. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  40. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  41. ])
  42. XYZ_TO_RGB = [
  43. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  44. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  45. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  46. ]
  47. # https://bottosson.github.io/posts/oklab/
  48. XYZ_TO_LMS = np.array([
  49. [0.8189330101, 0.3618667424, -0.1288597137],
  50. [0.0329845436, 0.9293118715, 0.0361456387],
  51. [0.0482003018, 0.2643662691, 0.6338517070],
  52. ])
  53. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  54. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  55. LMS_TO_OKLAB = np.array([
  56. [0.2104542553, 0.7936177850, -0.0040720468],
  57. [1.9779984951, -2.4285922050, 0.4505937099],
  58. [0.0259040371, 0.7827717662, -0.8086757660],
  59. ])
  60. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  61. def oklab2hex(pixel: np.array) -> str:
  62. # no need for a vectorized version, this is only for providing the mean hex
  63. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  64. def srgb2oklab(pixels: np.array) -> np.array:
  65. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  66. Stats = NamedTuple("Stats", [
  67. ("size", int),
  68. ("variance", float),
  69. ("stddev", float),
  70. ("hex", str),
  71. ("Lbar", float),
  72. ("abar", float),
  73. ("bbar", float),
  74. ("Cbar", float),
  75. ("hbar", float),
  76. ("Lhat", float),
  77. ("ahat", float),
  78. ("bhat", float),
  79. ])
  80. Data = NamedTuple("Data", [
  81. ("total", Stats),
  82. ("clusters", list[Stats]),
  83. ])
  84. FormInfo = NamedTuple("FormData", [
  85. ("name", str),
  86. ("traits", list[str]),
  87. ("types", list[str]),
  88. ("color", str),
  89. ("data", Data | None),
  90. ])
  91. Pokemon = NamedTuple("Pokemon", [
  92. ("num", int),
  93. ("species", str),
  94. ("sprite", str | None),
  95. ("forms", list[FormInfo]),
  96. ])
  97. def calc_statistics(pixels: np.array) -> Stats:
  98. # mean pixel of the image, (L-bar, a-bar, b-bar)
  99. mean = pixels.mean(axis=0)
  100. # square each component
  101. squared = pixels ** 2
  102. # Euclidean norm squared by summing squared components
  103. sqnorms = squared.sum(axis=1)
  104. # mean pixel of normalized image, (L-hat, a-hat, b-hat)
  105. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  106. # variance = mean(||p||^2) - ||mean(p)||^2
  107. variance = sqnorms.mean(axis=0) - sum(mean ** 2)
  108. # chroma^2 = a^2 + b^2
  109. chroma = np.sqrt(squared[:, 1:].sum(axis=1))
  110. # hue = atan2(b, a), but we need a circular mean
  111. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  112. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  113. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  114. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).mean(axis=0))
  115. return Stats(
  116. size=len(pixels),
  117. variance=variance,
  118. stddev=math.sqrt(variance),
  119. hex=oklab2hex(mean),
  120. Lbar=mean[0],
  121. abar=mean[1],
  122. bbar=mean[2],
  123. Cbar=chroma.mean(axis=0),
  124. hbar=hue * 180 / math.pi,
  125. Lhat=tilt[0],
  126. ahat=tilt[1],
  127. bhat=tilt[2],
  128. )
  129. def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  130. means, labels = max(
  131. (
  132. # Try k = 2, 3, and 4, and try a few times for each
  133. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  134. for k in (2, 3, 4)
  135. for i in range(cluster_attempts)
  136. ),
  137. key=lambda c:
  138. # Evaluate clustering by seeing the average distance in the ab-plane
  139. # between the centers. Maximizing this means the clusters are highly
  140. # distinct, which gives a sense of which k was best.
  141. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  142. .sum(axis=1)
  143. .mean(axis=0)
  144. )
  145. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  146. def get_pixels(img: Image) -> np.array:
  147. rgb = []
  148. for fr in range(getattr(img, "n_frames", 1)):
  149. img.seek(fr)
  150. rgb += [
  151. [r, g, b]
  152. for r, g, b, a in img.convert("RGBA").getdata()
  153. if a > 0 and (r, g, b) != (0, 0, 0)
  154. ]
  155. return srgb2oklab(np.array(rgb))
  156. async def load_image(session: ClientSession, url: str) -> Image.Image:
  157. async with session.get(url) as res:
  158. return Image.open(BytesIO(await res.read()))
  159. async def load_all_images(urls: list[str]) -> list[Image.Image]:
  160. async with ClientSession() as session:
  161. # TODO error handling
  162. return await asyncio.gather(*(load_image(session, url) for url in urls))
  163. def get_data(urls: list[str], seed=0) -> Data:
  164. images = asyncio.get_event_loop().run_until_complete(load_all_images(urls))
  165. # TODO error handling
  166. pixels = np.concatenate([get_pixels(img) for img in images])
  167. return Data(
  168. total=calc_statistics(pixels),
  169. clusters=find_clusters(pixels, seed=seed),
  170. )
  171. def get_traits(species: str, form: dict) -> list[str]:
  172. kind = form["formeKind"]
  173. traits = []
  174. if kind in ("mega", "mega-x", "mega-y", "primal"):
  175. traits.extend(("mega", "nostart"))
  176. if kind in ("gmax", "eternamax", "rapid-strike-gmax"):
  177. traits.extend(("gmax", "nostart"))
  178. if kind in ("alola", "galar", "hisui", "galar", "paldea"):
  179. traits.extend(("regional", kind))
  180. # special cases
  181. if species == "Tauros" and "-paldea" in kind:
  182. # paldean tauros has dumb names
  183. traits.extend(("regional", "paldea"))
  184. if species == "Minior" and kind != "meteor":
  185. # minior can only start the battle in meteor form
  186. traits.append("nostart")
  187. if species == "Darmanitan" and "zen" in kind:
  188. # darmanitan cannot start in zen form
  189. traits.append("nostart")
  190. if "galar" in kind:
  191. # also there's a galar-zen form to handle
  192. traits.extend(("regional", "galar"))
  193. if species == "Palafin" and kind == "hero":
  194. # palafin can only start in zero form
  195. traits.append("nostart")
  196. if species == "Gimmighoul" and kind == "roaming":
  197. # gimmighoul roaming is only in PGO
  198. traits.append("nostart")
  199. return list(set(traits))
  200. # https://bulbapedia.bulbagarden.net/wiki/List_of_Pok%C3%A9mon_with_gender_differences
  201. # there are some pokemon with notable gender diffs that the dex doesn't cover
  202. # judgement calls made arbitrarily
  203. GENDER_DIFFS = (
  204. "hippopotas", "hippowdon",
  205. "unfezant", "frillish", "jellicent",
  206. "pyroar",
  207. # meowstic, indeedee, basculegion, oinkologne are already handled in the dex
  208. )
  209. def load_pokedex(path: str) -> Generator[Pokemon, None, None]:
  210. with open(path) as infile:
  211. pkdx_raw = json.load(infile)
  212. pkdx = defaultdict(list)
  213. for key, entry in pkdx_raw.items():
  214. num = entry["num"]
  215. # non-cosmetic forms get separate entries automatically
  216. # but keeping the separate unown forms would be ridiculous
  217. if key != "unown" and len(cosmetic := entry.get("cosmeticFormes", [])) > 0:
  218. cosmetic.append(f'{entry["name"]}-{entry["baseForme"]}')
  219. if key == "alcremie":
  220. # oh god this thing
  221. cosmetic = [
  222. f"{cf}-{sweet}"
  223. for cf in cosmetic
  224. for sweet in [
  225. "Strawberry", "Berry", "Love", "Star",
  226. "Clover", "Flower", "Ribbon",
  227. ]
  228. ]
  229. pkdx[num].extend({
  230. **entry,
  231. "forme": cf.replace(" ", "-"),
  232. "formeKind": "cosmetic",
  233. } for cf in cosmetic)
  234. elif key in GENDER_DIFFS:
  235. pkdx[num].append({
  236. **entry,
  237. "forme": f'{entry["name"]}-M',
  238. "formeKind": "cosmetic",
  239. })
  240. pkdx[num].append({
  241. **entry,
  242. "forme": f'{entry["name"]}-F',
  243. "formeKind": "cosmetic",
  244. })
  245. else:
  246. pkdx[num].append({
  247. **entry,
  248. "forme": entry["name"],
  249. "formeKind": entry.get("forme", "base").lower(),
  250. })
  251. for i in range(1, max(pkdx.keys()) + 1):
  252. forms = pkdx[i]
  253. # double check there's no skipped entries
  254. assert len(forms) > 0
  255. # yield forms
  256. species = forms[0].get("baseSpecies", forms[0]["name"])
  257. yield Pokemon(
  258. num=i,
  259. species=species,
  260. sprite=None, # found later
  261. forms=[
  262. FormInfo(
  263. name=f.get("forme", f["name"]),
  264. traits=get_traits(species, f),
  265. types=f["types"],
  266. color=f["color"],
  267. data=None, # found later
  268. ) for f in forms
  269. ]
  270. )
  271. if __name__ == "__main__":
  272. from sys import argv
  273. dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
  274. out_file = argv[2] if len(argv) > 2 else "data/database-latest.js"
  275. log_file = argv[3] if len(argv) > 2 else "ingest.log"
  276. pkdx = list(load_pokedex())
  277. print(json.dumps(pkdx[5], indent=2))
  278. print(json.dumps(pkdx[285], indent=2))
  279. print(json.dumps(pkdx[773], indent=2))
  280. # with multiprocessing.Pool(4) as pool:
  281. # yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), pokemon, 100)