ingest2.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import math
  2. import asyncio
  3. import multiprocessing
  4. import json
  5. from collections import defaultdict
  6. from io import BytesIO
  7. from typing import NamedTuple, Generator
  8. from itertools import combinations
  9. import numpy as np
  10. from PIL import Image
  11. from aiohttp import ClientSession
  12. from scipy.cluster import vq
  13. """
  14. Goals:
  15. + Single module
  16. + Use OKLab
  17. + Improved clustering logic
  18. + Parallel, in the same way as anim-ingest
  19. + Async requests for downloads
  20. * Include more info about the pokemon (form, display name, icon sprite source)
  21. * Include more images (get more stills from pokemondb + serebii)
  22. * Include shinies + megas, tagged so the UI can filter them
  23. * Fallback automatically (try showdown animated, then showdown gen5, then pdb)
  24. * Filtering system more explicit and easier to work around
  25. * Output a record of ingest for auditing
  26. * Automatic retry of a partially failed ingest, using record
  27. """
  28. # https://en.wikipedia.org/wiki/SRGB#Transformation
  29. linearize_srgb = np.vectorize(
  30. lambda v: (v / 12.92) if v <= 0.04045 else (((v + 0.055) / 1.055) ** 2.4)
  31. )
  32. delinearize_lrgb = np.vectorize(
  33. lambda v: (v * 12.92) if v <= 0.0031308 else ((v ** (1 / 2.4)) * 1.055 - 0.055)
  34. )
  35. # https://mina86.com/2019/srgb-xyz-matrix/
  36. RGB_TO_XYZ = np.array([
  37. [33786752 / 81924984, 29295110 / 81924984, 14783675 / 81924984],
  38. [8710647 / 40962492, 29295110 / 40962492, 2956735 / 40962492],
  39. [4751262 / 245774952, 29295110 / 245774952, 233582065 / 245774952],
  40. ])
  41. XYZ_TO_RGB = [
  42. [4277208 / 1319795, -2028932 / 1319795, -658032 / 1319795],
  43. [-70985202 / 73237775, 137391598 / 73237775, 3043398 / 73237775],
  44. [164508 / 2956735, -603196 / 2956735, 3125652 / 2956735],
  45. ]
  46. # https://bottosson.github.io/posts/oklab/
  47. XYZ_TO_LMS = np.array([
  48. [0.8189330101, 0.3618667424, -0.1288597137],
  49. [0.0329845436, 0.9293118715, 0.0361456387],
  50. [0.0482003018, 0.2643662691, 0.6338517070],
  51. ])
  52. RGB_TO_LMS = XYZ_TO_LMS @ RGB_TO_XYZ
  53. LMS_TO_RGB = np.linalg.inv(RGB_TO_LMS)
  54. LMS_TO_OKLAB = np.array([
  55. [0.2104542553, 0.7936177850, -0.0040720468],
  56. [1.9779984951, -2.4285922050, 0.4505937099],
  57. [0.0259040371, 0.7827717662, -0.8086757660],
  58. ])
  59. OKLAB_TO_LMS = np.linalg.inv(LMS_TO_OKLAB)
  60. def oklab2hex(pixel: np.array) -> str:
  61. # no need for a vectorized version, this is only for providing the mean hex
  62. return "#" + "".join(f"{int(x * 255):02X}" for x in delinearize_lrgb(((pixel @ OKLAB_TO_LMS.T) ** 3) @ LMS_TO_RGB.T))
  63. def srgb2oklab(pixels: np.array) -> np.array:
  64. return (linearize_srgb(pixels / 255) @ RGB_TO_LMS.T) ** (1 / 3) @ LMS_TO_OKLAB.T
  65. Stats = NamedTuple("Stats", [
  66. ("size", int),
  67. ("variance", float),
  68. ("stddev", float),
  69. ("hex", str),
  70. ("Lbar", float),
  71. ("abar", float),
  72. ("bbar", float),
  73. ("Cbar", float),
  74. ("hbar", float),
  75. ("Lhat", float),
  76. ("ahat", float),
  77. ("bhat", float),
  78. ])
  79. def calc_statistics(pixels: np.array) -> Stats:
  80. # mean pixel of the image, (L-bar, a-bar, b-bar)
  81. mean = pixels.mean(axis=0)
  82. # square each component
  83. squared = pixels ** 2
  84. # Euclidean norm squared by summing squared components
  85. sqnorms = squared.sum(axis=1)
  86. # mean pixel of normalized image, (L-hat, a-hat, b-hat)
  87. tilt = (pixels / np.sqrt(sqnorms)[:, np.newaxis]).mean(axis=0)
  88. # variance = mean(||p||^2) - ||mean(p)||^2
  89. variance = sqnorms.mean(axis=0) - sum(mean ** 2)
  90. # chroma^2 = a^2 + b^2
  91. chroma = np.sqrt(squared[:, 1:].sum(axis=1))
  92. # hue = atan2(b, a), but we need a circular mean
  93. # https://en.wikipedia.org/wiki/Circular_mean#Definition
  94. # cos(atan2(b, a)) = a / sqrt(a^2 + b^2) = a / chroma
  95. # sin(atan2(b, a)) = b / sqrt(a^2 + b^2) = b / chroma
  96. hue = math.atan2(*(pixels[:, [2, 1]] / chroma[:, np.newaxis]).mean(axis=0))
  97. return Stats(
  98. size=len(pixels),
  99. variance=variance,
  100. stddev=math.sqrt(variance),
  101. hex=oklab2hex(mean),
  102. Lbar=mean[0],
  103. abar=mean[1],
  104. bbar=mean[2],
  105. Cbar=chroma.mean(axis=0),
  106. hbar=hue * 180 / math.pi,
  107. Lhat=tilt[0],
  108. ahat=tilt[1],
  109. bhat=tilt[2],
  110. )
  111. def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
  112. means, labels = max(
  113. (
  114. # Try k = 2, 3, and 4, and try a few times for each
  115. vq.kmeans2(pixels.astype(float), k, minit="++", seed=seed + i)
  116. for k in (2, 3, 4)
  117. for i in range(cluster_attempts)
  118. ),
  119. key=lambda c:
  120. # Evaluate clustering by seeing the average distance in the ab-plane
  121. # between the centers. Maximizing this means the clusters are highly
  122. # distinct, which gives a sense of which k was best.
  123. (np.array([m1 - m2 for m1, m2 in combinations(c[0][:, 1:], 2)]) ** 2)
  124. .sum(axis=1)
  125. .mean(axis=0)
  126. )
  127. return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
  128. Data = NamedTuple("Data", [
  129. ("name", str),
  130. ("sprite", str),
  131. ("traits", list[str]),
  132. ("total", Stats),
  133. ("clusters", list[Stats]),
  134. ])
  135. def get_pixels(img: Image) -> np.array:
  136. rgb = []
  137. for fr in range(getattr(img, "n_frames", 1)):
  138. img.seek(fr)
  139. rgb += [
  140. [r, g, b]
  141. for r, g, b, a in img.convert("RGBA").getdata()
  142. if a > 0 and (r, g, b) != (0, 0, 0)
  143. ]
  144. return srgb2oklab(np.array(rgb))
  145. async def load_image(session: ClientSession, url: str) -> Image.Image:
  146. async with session.get(url) as res:
  147. return Image.open(BytesIO(await res.read()))
  148. async def load_all_images(urls: list[str]) -> list[Image.Image]:
  149. async with ClientSession() as session:
  150. # TODO error handling
  151. return await asyncio.gather(*(load_image(session, url) for url in urls))
  152. def get_data(name, seed=0) -> Data:
  153. images = asyncio.get_event_loop().run_until_complete(load_all_images([
  154. # TODO source images
  155. ]))
  156. # TODO error handling
  157. pixels = np.concatenate([get_pixels(img) for img in images])
  158. return Data(
  159. # TODO name normalization
  160. name=name,
  161. # TODO sprite URL discovery
  162. sprite=f"https://img.pokemondb.net/sprites/sword-shield/icon/{name}.png",
  163. # TODO trait analysis
  164. traits=[],
  165. total=calc_statistics(pixels),
  166. clusters=find_clusters(pixels, seed=seed),
  167. )
  168. def get_data_for_all(pokemon: list[str], seed=0) -> Generator[Data, None, None]:
  169. with multiprocessing.Pool(4) as pool:
  170. yield from pool.imap_unordered(lambda n: get_data(n, seed=seed), enumerate(pokemon), 100)
  171. def name2id(name: str) -> str:
  172. return name.replace(" ", "").replace("-", "").lower()
  173. def load_pokedex(path: str) -> dict:
  174. with open(path) as infile:
  175. pkdx_raw = json.load(infile)
  176. pkdx = defaultdict(list)
  177. for key, entry in pkdx_raw.items():
  178. num = entry["num"]
  179. # non-cosmetic forms get separate entries automatically
  180. # but keeping the separate unown forms would be ridiculous
  181. if key != "unown" and len(cosmetic := entry.get("cosmeticFormes", [])) > 0:
  182. cosmetic.append(f'{key}-{entry["baseForme"].replace(" ", "-")}')
  183. if key == "alcremie":
  184. # oh god this thing
  185. cosmetic = [
  186. f"{cf}-{sweet}"
  187. for cf in cosmetic
  188. for sweet in [
  189. "Strawberry", "Berry", "Love", "Star",
  190. "Clover", "Flower", "Ribbon",
  191. ]
  192. ]
  193. pkdx[num].extend((name2id(cf), {
  194. **entry,
  195. "forme": cf,
  196. }) for cf in cosmetic)
  197. else:
  198. pkdx[num].append((key, entry))
  199. for i in range(min(pkdx.keys()), max(pkdx.keys()) + 1):
  200. # double check there's no skipped entries
  201. assert len(pkdx[i]) > 0
  202. return pkdx
  203. if __name__ == "__main__":
  204. from sys import argv
  205. load_pokedex(argv[1] if len(argv) > 1 else "data/pokedex.json")