ingest2.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import math
  2. import asyncio
  3. import multiprocessing
  4. import json
  5. from collections import defaultdict
  6. from io import BytesIO
  7. from typing import NamedTuple, Generator
  8. from itertools import combinations
  9. import numpy as np
  10. from PIL import Image
  11. from aiohttp import ClientSession
  12. from scipy.cluster import vq
  13. """
  14. Goals:
  15. + Single module
  16. + Use OKLab
  17. + Improved clustering logic
  18. + Parallel, in the same way as anim-ingest
  19. + Async requests for downloads
  20. * Include more info about the pokemon (form, display name, icon sprite source)
  21. * Include more images (get more stills from pokemondb + serebii)
  22. * Include shinies + megas, tagged so the UI can filter them
  23. * Fall back automatically (try showdown animated, then showdown gen5, then pdb)
  24. * Filtering system more explicit and easier to work around
  25. * Output a record of ingest for auditing
  26. * Automatic retry of a partially failed ingest, using record
  27. """
  28. # https://en.wikipedia.org/wiki/SRGB#Transformation
  29. linearize_srgb = np.vectorize(
  30. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  31. )
  32. delinearize_lrgb = np.vectorize(
  33. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  34. )
  35. # https://mina86.com/2019/srgb-xyz-matrix/
  36. RGB_TO_XYZ = np.array([
  37. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  38. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  39. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  40. ])
  41. XYZ_TO_RGB = [
  42. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  43. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  44. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  45. ]
  46. # https://bottosson.github.io/posts/oklab/
  47. XYZ_TO_LMS = np.array([
  48. [0.8189330101, 0.3618667424, -0.1288597137],
  49. [0.0329845436, 0.9293118715, 0.0361456387],
  50. [0.0482003018, 0.2643662691, 0.6338517070],
  51. ])
  52. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  53. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  54. LMS_TO_OKLAB = np.array([
  55. [0.2104542553, 0.7936177850, -0.0040720468],
  56. [1.9779984951, -2.4285922050, 0.4505937099],
  57. [0.0259040371, 0.7827717662, -0.8086757660],
  58. ])
  59. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  60. def oklab2hex(pixel: np.array) -> str:
  61. # no need for a vectorized version, this is only for providing the mean hex
  62. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  63. def srgb2oklab(pixels: np.array) -> np.array:
  64. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  65. Stats = NamedTuple("Stats", [
  66. ("size", int),
  67. ("variance", float),
  68. ("stddev", float),
  69. ("hex", str),
  70. ("Lbar", float),
  71. ("abar", float),
  72. ("bbar", float),
  73. ("Cbar", float),
  74. ("hbar", float),
  75. ("Lhat", float),
  76. ("ahat", float),
  77. ("bhat", float),
  78. ])
  79. def calc_statistics(pixels: np.array) -> Stats:
  80. # mean pixel of the image, (L-bar, a-bar, b-bar)
  81. mean = pixels.mean(axis=0)
  82. # square each component
  83. squared = pixels ** 2
  84. # Euclidean norm squared by summing squared components
  85. sqnorms = squared.sum(axis=1)
  86. # mean pixel of normalized image, (L-hat, a-hat, b-hat)
  87. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  88. # variance = mean(||p||^2) - ||mean(p)||^2
  89. variance = sqnorms.mean(axis=0) - sum(mean ** 2)
  90. # chroma^2 = a^2 + b^2, C-bar = mean(sqrt(a^2 + b^2))
  91. chroma = np.sqrt(squared[:, 1:].sum(axis=1)).mean(axis=0)
  92. # hue = atan2(b, a), h-bar = mean(atan2(b, a))
  93. hue = np.arctan2(pixels[:, 2], pixels[:, 1]).mean(axis=0) * 180 / math.pi
  94. return Stats(
  95. size=len(pixels),
  96. variance=variance,
  97. stddev=math.sqrt(variance),
  98. hex=oklab2hex(mean),
  99. Lbar=mean[0],
  100. abar=mean[1],
  101. bbar=mean[2],
  102. Cbar=chroma,
  103. hbar=hue,
  104. Lhat=tilt[0],
  105. ahat=tilt[1],
  106. bhat=tilt[2],
  107. )
  108. def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  109. means, labels = max(
  110. (
  111. # Try k = 2, 3, and 4, and try a few times for each
  112. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  113. for k in (2, 3, 4)
  114. for i in range(cluster_attempts)
  115. ),
  116. key=lambda c:
  117. # Evaluate clustering by seeing the average distance in the ab-plane
  118. # between the centers. Maximizing this means the clusters are highly
  119. # distinct, which gives a sense of which k was best.
  120. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  121. .sum(axis=1)
  122. .mean(axis=0)
  123. )
  124. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  125. Data = NamedTuple("Data", [
  126. ("name", str),
  127. ("sprite", str),
  128. ("traits", list[str]),
  129. ("total", Stats),
  130. ("clusters", list[Stats]),
  131. ])
  132. def get_pixels(img: Image) -> np.array:
  133. rgb = []
  134. for fr in range(getattr(img, "n_frames", 1)):
  135. img.seek(fr)
  136. rgb += [
  137. [r, g, b]
  138. for r, g, b, a in img.convert("RGBA").getdata()
  139. if a > 0 and (r, g, b) != (0, 0, 0)
  140. ]
  141. return srgb2oklab(np.array(rgb))
  142. async def load_image(session: ClientSession, url: str) -> Image.Image:
  143. async with session.get(url) as res:
  144. return Image.open(BytesIO(await res.read()))
  145. async def load_all_images(urls: list[str]) -> list[Image.Image]:
  146. async with ClientSession() as session:
  147. # TODO error handling
  148. return await asyncio.gather(*(load_image(session, url) for url in urls))
  149. def get_data(name, seed=0) -> Data:
  150. images = asyncio.get_event_loop().run_until_complete(load_all_images([
  151. # TODO source images
  152. ]))
  153. # TODO error handling
  154. pixels = np.concatenate([get_pixels(img) for img in images])
  155. return Data(
  156. # TODO name normalization
  157. name=name,
  158. # TODO sprite URL discovery
  159. sprite=f"https://img.pokemondb.net/sprites/sword-shield/icon/{name}.png",
  160. # TODO trait analysis
  161. traits=[],
  162. total=calc_statistics(pixels),
  163. clusters=find_clusters(pixels, seed=seed),
  164. )
  165. def get_data_for_all(pokemon: list[str], seed=0) -> Generator[Data, None, None]:
  166. with multiprocessing.Pool(4) as pool:
  167. yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), enumerate(pokemon), 100)
  168. def name2id(name: str) -> str:
  169. return name.replace(" ", "").replace("-", "").lower()
  170. def load_pokedex(path: str) -> dict:
  171. with open(path) as infile:
  172. pkdx_raw = json.load(infile)
  173. pkdx = defaultdict(list)
  174. for key, entry in pkdx_raw.items():
  175. num = entry["num"]
  176. # non-cosmetic forms get separate entries automatically
  177. # but keeping the separate unown forms would be ridiculous
  178. if key != "unown" and len(cosmetic := entry.get("cosmeticFormes", [])) > 0:
  179. cosmetic.append(f'{key}-{entry["baseForme"].replace(" ", "-")}')
  180. if key == "alcremie":
  181. # oh god this thing
  182. cosmetic = [
  183. f"{cf}-{sweet}"
  184. for cf in cosmetic
  185. for sweet in [
  186. "Strawberry", "Berry", "Love", "Star",
  187. "Clover", "Flower", "Ribbon",
  188. ]
  189. ]
  190. pkdx[num].extend((name2id(cf), {
  191. **entry,
  192. "forme": cf,
  193. }) for cf in cosmetic)
  194. else:
  195. pkdx[num].append((key, entry))
  196. for i in range(min(pkdx.keys()), max(pkdx.keys()) + 1):
  197. # double check there's no skipped entries
  198. assert len(pkdx[i]) > 0
  199. return pkdx
  200. if __name__ == "__main__":
  201. from sys import argv
  202. load_pokedex(argv[1] if len(argv) > 1 else "data/pokedex.json")