# anim_ingest.py
#
# Scrapes the animated Pokemon Showdown sprites, computes colour statistics
# for each sprite, and writes them to database-v3.js.
import io
import itertools
import math
from typing import Callable, NamedTuple, Optional

import numpy as np
import requests
from bs4 import BeautifulSoup
from colorspacious import cspace_convert
from PIL import Image
from scipy.cluster import vq

import ingest
# Base seed for k-means; attempt i for each cluster count uses cluster_seed + i.
cluster_seed = 20220328
# Number of k-means runs tried for each cluster count.
cluster_attempts = 10
# URL prefix of the front-facing animated sprites (used when loading images).
base = "https://play.pokemonshowdown.com/sprites/ani/"
# URL prefix of the back-facing animated sprites; its listing is also the
# source of the sprite file names in get_all_pokemon().
back_base = "https://play.pokemonshowdown.com/sprites/ani-back/"
# Filename prefixes to exclude: removes all forms of a pokemon, and also pokestars.
# Compared with str.startswith against the listed ".gif" file names.
# NOTE(review): "morepeko-" looks like a misspelling of "morpeko-"; the hangry
# form is excluded separately via full_filters -- confirm before changing.
start_with_filters = [
    # no significant visual changes
    "arceus-", "silvally-", "genesect-", "pumpkaboo-", "gourgeist-", "unown-", "giratina-",
    # cannot start the battle in alternate form
    "castform-", "cherrim-", "aegislash-", "xerneas-", "wishiwashi-", "eiscue-", "mimikyu-",
    "cramorant-", "morepeko-",
    # weird event thing
    "greninja-", "eevee-", "pikachu-", "zarude-", "magearna-",
    # pokestars
    "pokestar",
]
# Filename suffixes to exclude: removes every sprite of a given form type.
# Compared with str.endswith against the listed ".gif" file names; entries in
# force_keep are exempt.
end_with_filters = [
    "-mega.gif", "-megax.gif", "-megay.gif", "-primal.gif", "-ultra.gif",
    "-gmax.gif", "-eternamax.gif", "-totem.gif", "-f.gif", "-b.gif",
]
# Exact file names to exclude entirely (compared by equality with the listed
# ".gif" file names).
full_filters = [
    # darmanitan zen forms (cannot start in zen)
    "darmanitan-galarzen.gif", "darmanitan-zen.gif",
    # morpeko hangry (cannot start hangry)
    "morpeko-hangry.gif",
    # minior core forms (cannot start in anything but -meteor, renamed below)
    "minior.gif", "minior-blue.gif", "minior-green.gif", "minior-indigo.gif",
    "minior-orange.gif", "minior-violet.gif", "minior-yellow.gif",
    # because it is a create-a-pokemon
    "astrolotl.gif", "aurumoth.gif", "caribolt.gif", "cawmodore.gif", "chromera.gif", "crucibelle.gif",
    "equilibra.gif", "fidgit.gif", "jumbao.gif", "justyke.gif", "kerfluffle.gif", "kitsunoh.gif",
    "krilowatt.gif", "malaconda.gif", "miasmaw.gif", "mollux.gif", "naviathan.gif", "necturna.gif",
    "pajantom.gif", "plasmanta.gif", "pluffle.gif", "protowatt.gif", "scratchet.gif", "smogecko.gif",
    "smoguana.gif", "smokomodo.gif", "snaelstrom.gif", "stratagem.gif", "tomohawk.gif", "volkraken.gif", "voodoom.gif",
    # typos/duplicates
    "buffalant.gif", "klinklang-back.gif", "krikretot.gif",
    "pumpkabo-super.gif", "magcargo%20.gif", "meowstic-female.gif",
    "ratatta-a.gif", "ratatta-alola.gif", "raticate-a.gif",
    "rotom-h.gif", "rotom-m.gif", "rotom-s.gif", "rotom-w.gif",
    # not a pokemon
    "substitute.gif",
]
# Force certain pokemon to stay: these file names are kept even though they
# match a filter (each ends in "-f.gif", which end_with_filters would remove).
force_keep = ["meowstic-f.gif", "unfezant-f.gif", "pyroar-f.gif"]
# Rename certain pokemon after the fact.
# Keys are sprite names without the ".gif" extension; the mapping is applied
# only when a row is written to the output file (names without an entry are
# written unchanged).
rename = {
    # dash consistency
    "nidoranm": "nidoran-m",
    "nidoranf": "nidoran-f",
    "porygonz": "porygon-z",
    "tapubulu": "tapu-bulu",
    "tapufini": "tapu-fini",
    "tapukoko": "tapu-koko",
    "tapulele": "tapu-lele",
    "hooh": "ho-oh",
    "mimejr": "mime-jr",
    "mrmime": "mr-mime",
    "mrmime-galar": "mr-mime-galar",
    "mrrime": "mr-rime",
    "jangmoo": "jangmo-o",
    "hakamoo": "hakamo-o",
    "kommoo": "kommo-o",
    "typenull": "type-null",
    "oricorio-pompom": "oricorio-pom-pom",
    "necrozma-duskmane": "necrozma-dusk-mane",
    "necrozma-dawnwings": "necrozma-dawn-wings",
    "toxtricity-lowkey": "toxtricity-low-key",
    # rename forms
    "shellos": "shellos-west",
    "shaymin": "shaymin-land",
    "meloetta": "meloetta-aria",
    "keldeo": "keldeo-ordinary",
    "hoopa": "hoopa-confined",
    "burmy": "burmy-plant",
    "wormadam": "wormadam-plant",
    "deerling": "deerling-spring",
    "sawsbuck": "sawsbuck-spring",
    "vivillon": "vivillon-meadow",
    "basculin": "basculin-redstriped",
    "meowstic": "meowstic-male",
    "meowstic-f": "meowstic-female",
    "pyroar-f": "pyroar-female",
    "flabebe": "flabebe-red",
    "floette": "floette-red",
    "florges": "florges-red",
    # the core-form "minior.gif" is filtered out, so the meteor form takes the plain name
    "minior-meteor": "minior",
    "sinistea": "sinistea-phony",
    "polteageist": "polteageist-phony",
    "gastrodon": "gastrodon-west",
    "furfrou": "furfrou-natural",
    "wishiwashi": "wishiwashi-school",
    "tornadus": "tornadus-incarnate",
    "landorus": "landorus-incarnate",
    "thundurus": "thundurus-incarnate",
    "calyrex-ice": "calyrex-ice-rider",
    "calyrex-shadow": "calyrex-shadow-rider",
    "urshifu-rapidstrike": "urshifu-rapid-strike",
    "zacian": "zacian-hero",
    "zamazenta": "zamazenta-hero",
}
  114. def get_all_pokemon() -> list[str]:
  115. soup = BeautifulSoup(requests.get(back_base).text, "html.parser")
  116. gifs = [href for a in soup.find_all("a") if (href := a.get("href")).endswith("gif")]
  117. return [
  118. g[:-4]
  119. for g in gifs
  120. if g in force_keep or (
  121. g not in full_filters
  122. and not any(g.startswith(f) for f in start_with_filters)
  123. and not any(g.endswith(f) for f in end_with_filters)
  124. )
  125. ]
  126. def load_image(base: str, name: str) -> Image:
  127. return Image.open(io.BytesIO(requests.get(base + name + ".gif").content))
  128. def get_all_pixels(im: Image) -> list[tuple[int, int, int]]:
  129. rgb_pixels = []
  130. for fr in range(getattr(im, "n_frames", 1)):
  131. im.seek(fr)
  132. rgb_pixels += [
  133. (r, g, b)
  134. for r, g, b, a in im.convert("RGBA").getdata()
  135. if not ingest.is_outline(r, g, b, a)
  136. ]
  137. return rgb_pixels
  138. def merge_dist_jab(p: np.array, q: np.array) -> float:
  139. pj, pa, pb = p
  140. qj, qa, qb = q
  141. light_diff = abs(pj - qj)
  142. hue_angle = math.acos((pa * qa + pb * qb) / math.sqrt((pa ** 2 + pb ** 2) * (qa ** 2 + qb ** 2))) * 180 / math.pi
  143. return light_diff if hue_angle <= 10 else None
  144. def merge_dist_rgb(p: np.array, q: np.array) -> float:
  145. return merge_dist_jab(*cspace_convert(np.array([p, q]), "sRGB255", "CAM02-UCS"))
  146. def score_clustering_jab(means: list[np.array]) -> float:
  147. score = 0
  148. count = 0
  149. for p, q in itertools.combinations(means, 2):
  150. # squared dist in the a-b plane
  151. _, pa, pb = p
  152. _, qa, qb = q
  153. score += (pa - qa) ** 2 + (pb - qb) ** 2
  154. count += 1
  155. return score / count
  156. def score_clustering_rgb(means: list[np.array]) -> float:
  157. return score_clustering_jab(list(cspace_convert(np.array(means), "sRGB255", "CAM02-UCS")))
  158. Stats = NamedTuple("Stats", [("size", int), ("inertia", float), ("mu", np.array), ("nu", np.array)])
  159. def merge_stats(s1: Stats, s2: Stats) -> Stats:
  160. ts = s1.size + s2.size
  161. f1 = s1.size / ts
  162. f2 = s2.size / ts
  163. return Stats(
  164. size=ts,
  165. inertia=s1.inertia * f1 + s2.inertia * f2,
  166. mu=s1.mu * f1 + s2.mu * f2,
  167. nu=s1.nu * f1 + s2.nu * f2,
  168. )
  169. def flatten_stats(ss: list[Stats], target_len: int = 40) -> list[float]:
  170. to_return = []
  171. for s in ss:
  172. to_return += [s.size, s.inertia, *s.mu, *s.nu]
  173. return to_return + ([0] * (target_len - len(to_return)))
  174. def compute_stats(
  175. pixels: np.array,
  176. clustering_scorer: Callable[[list[np.array]], float],
  177. merge_dist: Callable[[np.array, np.array], float],
  178. ) -> tuple[Stats, Stats, Stats, Stats, Stats]:
  179. total_stats = Stats(
  180. size=len(pixels),
  181. inertia=ingest.inertia(pixels),
  182. mu=ingest.mu(pixels),
  183. nu=ingest.nu(pixels),
  184. )
  185. # run k-means multiple times, for multiple k's, trying to maximize the clustering_scorer
  186. best = None
  187. for k in (2, 3, 4):
  188. for i in range(cluster_attempts):
  189. means, labels = vq.kmeans2(pixels.astype(float), k, minit="++", seed=cluster_seed + i)
  190. score = clustering_scorer(means)
  191. if best is None or best[0] < score:
  192. best = (score, means, labels)
  193. _, best_means, best_labels = best
  194. cluster_stats = []
  195. for i in range(len(best_means)):
  196. cluster_pixels = pixels[best_labels == i]
  197. cluster_stats.append(Stats(
  198. size=len(cluster_pixels),
  199. inertia=ingest.inertia(cluster_pixels),
  200. mu=best_means[i],
  201. nu=ingest.nu(cluster_pixels),
  202. ))
  203. # assuming there are still more than two clusters,
  204. # attempt to merge the closest if they're close enough
  205. if len(cluster_stats) > 2:
  206. # first, find all the options
  207. options = []
  208. for i, j in itertools.combinations(range(len(cluster_stats)), 2):
  209. ci = cluster_stats[i]
  210. cj = cluster_stats[j]
  211. if (dist := merge_dist(ci.mu, cj.mu)) is not None:
  212. rest = [c for k, c in enumerate(cluster_stats) if k not in (i, j)]
  213. options.append((dist, [merge_stats(ci, cj), *rest]))
  214. # if there are multiple options, use the closest,
  215. # otherwise leaves cluster_stats the same since that's the default in the list
  216. if len(options) > 0:
  217. cluster_stats = min(options, key=lambda x: x[0])[1]
  218. return total_stats, *cluster_stats
  219. def get_stats(name: str) -> list[float]:
  220. front = get_all_pixels(load_image(base, name))
  221. back = get_all_pixels(load_image(back_base, name))
  222. rgb_pixels = np.array(front + back)
  223. jab_pixels = cspace_convert(rgb_pixels, "sRGB255", "CAM02-UCS")
  224. jab_stats = flatten_stats(compute_stats(
  225. jab_pixels,
  226. score_clustering_jab,
  227. merge_dist_jab,
  228. ))[1:]
  229. rgb_stats = flatten_stats(compute_stats(
  230. rgb_pixels,
  231. score_clustering_rgb,
  232. merge_dist_rgb,
  233. ))[1:]
  234. return [len(rgb_pixels), *jab_stats, *rgb_stats]
  235. if __name__ == "__main__":
  236. pkmn = get_all_pokemon()
  237. print("Found", len(pkmn), "sprites...")
  238. errors = []
  239. with open("database-v3.js", "w") as outfile:
  240. outfile.write("const databaseV3 = [\n")
  241. for name in pkmn:
  242. print("Ingesting", name, "...")
  243. try:
  244. stats = get_stats(name)
  245. outfile.write(f' [ "{rename.get(name, name)}", {", ".join(str(n) for n in stats)} ],\n')
  246. except Exception as e:
  247. print(e)
  248. errors.append((name, e))
  249. outfile.write("];\n")
  250. print("Errors:", errors)