浏览代码

First pass at actually downloading and analyzing some images

Kirk Trombley 2 年之前
父节点
当前提交
fd38565a27
共有 3 个文件被更改,包括 106 次插入和 136 次删除
  1. +0 -118
      ingest2.py
  2. +38 -8
      tools/analyze.py
  3. +68 -10
      tools/download.py

+ 0 - 118
ingest2.py

@@ -1,118 +0,0 @@
-"""
-Goals:
- + Use OKLab
- + Improved clustering logic
- + Parallel, in the same way as anim-ingest
- + Async requests for downloads
- + Include more info about the pokemon (form, display name, icon sprite source)
- + Include megas/gmax/etc, tagged so the UI can filter them
- * Include more images (get more stills from pokemondb + serebii)
- * Include shinies
- * Fallback automatically (try showdown animated, then showdown gen5, then pdb)
- * Filtering system more explicit and easier to work around
- * Output a record of ingest for auditing
- * Automatic retry of a partially failed ingest, using record
-"""
-# async def load_image(session: ClientSession, url: str) -> Image.Image:
-#   async with session.get(url) as res:
-#     res.raise_for_status()
-#     return Image.open(BytesIO(await res.read()))
-
-
-# async def load_all_images(urls: list[str]) -> tuple[list[Image.Image], list[Exception]]:
-#   async with ClientSession() as session:
-#     results = await asyncio.gather(
-#       *(load_image(session, url) for url in urls),
-#       return_exceptions=True
-#     )
-#   success = []
-#   errors = []
-#   for r in results:
-#     (success if isinstance(r, Image.Image) else errors).append(r)
-#   return success, errors
-
-
-# def get_urls(target: Pokemon, form: FormInfo) -> list[str]:
-#   lower_name = form.name.lower()
-#   return [
-#     f"https://play.pokemonshowdown.com/sprites/ani/{lower_name}.gif",
-#     f"https://play.pokemonshowdown.com/sprites/ani-back/{lower_name}.gif",
-#     f"https://play.pokemonshowdown.com/sprites/gen5/{lower_name}.png",
-#     f"https://play.pokemonshowdown.com/sprites/gen5-back/{lower_name}.png",
-#     f"https://img.pokemondb.net/sprites/home/normal/{lower_name}.png",
-#     # TODO other sources - want to make sure we never cross contaminate though...
-#     # if we pull the wrong form for something it will be a nightmare to debug
-#     # f"https://www.serebii.net/scarletviolet/pokemon/new/{target.num}-{???}.png"
-#     # f"https://www.serebii.net/pokemon/art/{target.num}-{???}.png"
-#   ]
-
-
-# async def set_data(target: Pokemon, seed=0) -> list[Exception]:
-#   all_errors = []
-#   for form in target.forms:
-#     print(f" #{target.num} - Ingesting Form: {form.name}")
-#     urls = get_urls(target, form)
-#     print(f"  #{target.num} - Attempting {len(urls)} potential sources")
-#     images, errors = await load_all_images(urls)
-#     all_errors.extend(errors)
-#     print(f"  #{target.num} - Loaded {len(images)} sources")
-#     try:
-#       pixels = np.concatenate([get_pixels(img) for img in images])
-#       print(f"  #{target.num} - Summarizing {len(pixels)} total pixels")
-#       total = calc_statistics(pixels)
-#       print(f"  #{target.num} - Begin clustering")
-#       clusters = find_clusters(pixels, seed=seed)
-#       print(f"  #{target.num} - End clustering, chose k={len(clusters)}")
-#       form.data = Data(total=total, clusters=clusters)
-#     except Exception as e:
-#       all_errors.append(e)
-#   return all_errors
-
-
-
-# async def ingest(pool_size: int, seed: int) -> tuple[list[str], list[str]]:
-#   computed = []
-#   errors = []
-#   loop = asyncio.get_event_loop()
-#   with ProcessPoolExecutor(pool_size) as exec:
-#     print(f"Ingesting #{start} - #{end}")
-#     for pkmn in pkdx[start - 1:end]:
-#       print(f"Ingesting #{pkmn.num}: {pkmn.species}...")
-#       new_errors = await set_data(pkmn, seed)
-#       loop.run_in_executor(exec, set_data, pkmn, seed)
-
-    
-
-#       computed.append(loop.run_in_executor(pool, ingest(p)))
-
-#   try:
-#     errors.extend(new_errors)
-#     print(f"Finished #{pkmn.num}: {len(new_errors)} error(s)")
-#     return json.dumps(asdict(pkmn))
-#   except Exception as e:
-#     print(e)
-#     errors.append(e)
-
-# if __name__ == "__main__":
-  # from sys import argv
-  # dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
-  # out_file = argv[2] if len(argv) > 2 else "data/database-latest.db"
-  # dex_span = argv[3] if len(argv) > 3 else "1-151"
-  # log_file = argv[4] if len(argv) > 4 else "errors-latest.log"
-  # set_seed = argv[5] if len(argv) > 5 else "20230304"
-
-  # start, end = map(int, dex_span.split("-", maxsplit=1))
-  # seed = int(set_seed)
-  # errors = []
-
-  # pkdx = list(load_pokedex(dex_file))
-  # loop = asyncio.new_event_loop()
-
-  # with open(log_file, "w") as log:
-  #   # TODO better logging
-  #   log.writelines(str(e) for e in errors)
-
-  # with open(out_file, "a") as db:
-  #   for _, line in computed:
-  #     db.write(line)
-  #     db.write("\n")

+ 38 - 8
tools/analyze.py

@@ -1,5 +1,6 @@
 import math
-from dataclasses import dataclass
+from pathlib import Path
+from dataclasses import dataclass, asdict
 from itertools import combinations
 
 import numpy as np
@@ -88,14 +89,14 @@ def calc_statistics(pixels: np.array) -> Stats:
     centroid=list(np.round(centroid, OUTPUT_PRECISION)),
     tilt=list(np.round(tilt, OUTPUT_PRECISION)),
     size=len(pixels),
-    variance=math.round(variance, OUTPUT_PRECISION),
-    chroma=math.round(chroma.mean(axis=0), OUTPUT_PRECISION),
-    hue=math.round(hue % (2 * math.pi), OUTPUT_PRECISION),
+    variance=round(variance, OUTPUT_PRECISION),
+    chroma=round(chroma.mean(axis=0), OUTPUT_PRECISION),
+    hue=round(hue % (2 * math.pi), OUTPUT_PRECISION),
     hex=oklab2hex(centroid),
   )
 
 
-def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
+def calc_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
   means, labels = max(
     (
       # Try k = 2, 3, and 4, and try a few times for each
@@ -116,7 +117,7 @@ def find_clusters(pixels: np.array, cluster_attempts=5, seed=0) -> list[Stats]:
   return [calc_statistics(pixels[labels == i]) for i in range(len(means))]
 
 
-def get_pixels(img: Image.Image) -> np.array:
+def get_srgb_pixels(img: Image.Image) -> np.array:
   rgb = []
   for fr in range(getattr(img, "n_frames", 1)):
     img.seek(fr)
@@ -125,8 +126,37 @@ def get_pixels(img: Image.Image) -> np.array:
       for r, g, b, a in img.convert("RGBA").getdata()
       if a > 0 and (r, g, b) != (0, 0, 0)
     ]
-  return srgb2oklab(np.array(rgb))
+  return np.array(rgb)
 
 
 if __name__ == "__main__":
-  print("TODO")
+  from sys import argv
+  dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
+  image_dir = argv[2] if len(argv) > 2 else "images"
+  seed = int(argv[3]) if len(argv) > 3 else 230308
+
+  import os
+  from collections import defaultdict
+
+  to_process = defaultdict(list)
+  for image_filename in os.listdir(image_dir):
+    form_name = image_filename.rsplit("-", maxsplit=1)[0]
+    to_process[form_name].append(Path(image_dir, image_filename))
+
+  # TODO multiproc
+  database = []
+  for form, image_files in to_process.items():
+    all_pixels = np.concatenate([
+      get_srgb_pixels(Image.open(fn)) for fn in image_files
+    ])
+    oklab = srgb2oklab(all_pixels)
+    database.append({
+      "name": form,
+      # TODO also get dex info - species, color, etc.
+      "total": asdict(calc_statistics(oklab)),
+      "clusters": [asdict(c) for c in calc_clusters(oklab, seed=seed)],
+    })
+
+  # TODO real output
+  import json
+  print(json.dumps(database, indent=2))

+ 68 - 10
tools/download.py

@@ -5,6 +5,7 @@ Manage the logic of downloading the pokedex and source images.
 import re
 import json
 import asyncio
+from pathlib import Path
 from dataclasses import dataclass, asdict
 from collections import defaultdict
 
@@ -164,19 +165,76 @@ def clean_dex(raw: dict) -> dict[int, Pokemon]:
   }
 
 
-async def main(dex_file: str):
-  # first download the pokedex
-  raw_dex = await load_pokedex()
-  # clean and reorganize it
-  dex = clean_dex(raw_dex)
-  # output dex for auditing
-  with open(dex_file, "w") as out:
-    json.dump({str(i): asdict(pkmn) for i, pkmn in dex.items()}, out, indent=2)
-  # TODO actually progress to images
+def get_showdown_urls(species: str, form: Form) -> list[tuple[str, str]]:
+  name = form.name.lower().replace("mega-y", "megay").replace("mega-x", "megax")
+  return [
+    (f"https://play.pokemonshowdown.com/sprites/ani/{name}.gif", "gif"),
+    (f"https://play.pokemonshowdown.com/sprites/ani-back/{name}.gif", "gif"),
+    (f"https://play.pokemonshowdown.com/sprites/gen5/{name}.png", "png"),
+    (f"https://play.pokemonshowdown.com/sprites/gen5-back/{name}.png", "png"),
+  ]
+
+
+async def download(session: ClientSession, url: str, filename: str) -> tuple[str, Exception | bool]:
+  if Path(filename).is_file():
+    return url, False
+  try:
+    async with session.get(url) as res:
+      res.raise_for_status()
+      with open(filename, "wb") as out:
+        out.write(await res.read())
+  except Exception as ex:
+    return url, ex
+  return url, True
+
+
+async def download_all(pkmn: Pokemon, image_dir: str) -> dict[str, dict[str, Exception | bool]]:
+  results = defaultdict(dict)
+  async with ClientSession() as session:
+    for form in pkmn.forms:
+      urls = []
+      urls += get_showdown_urls(pkmn.species, form)
+      # TODO more
+      results[form.name].update(await asyncio.gather(*[
+        download(session, url, f"{image_dir}/{form.name}-{i}.{ext}")
+        for i, (url, ext) in enumerate(urls)
+      ]))
+  return results
+
+
+async def main(dex_file: str, image_dir: str):
+  if Path(dex_file).is_file():
+    with open(dex_file) as infile:
+      loaded = json.load(infile)
+    dex = {
+      int(num): Pokemon(
+        num=entry["num"],
+        species=entry["species"],
+        forms=[Form(**f) for f in entry["forms"]],
+      ) for num, entry in loaded.items()
+    }
+  else:
+    # first download the pokedex
+    raw_dex = await load_pokedex()
+    # clean and reorganize it
+    dex = clean_dex(raw_dex)
+    # output dex for auditing and reloading
+    with open(dex_file, "w") as out:
+      json.dump({
+        str(i): asdict(pkmn)
+        for i, pkmn in dex.items()
+      }, out, indent=2)
+
+  Path(image_dir).mkdir(parents=True, exist_ok=True)
+  log = await download_all(dex[286], image_dir)
+  for url, result in log.items():
+    print(url, "-", str(result))
+  # TODO actually get all images
 
 
 if __name__ == "__main__":
   from sys import argv
   dex_file = argv[1] if len(argv) > 1 else "data/pokedex.json"
+  image_dir = argv[2] if len(argv) > 2 else "images"
 
-  asyncio.run(main(dex_file))
+  asyncio.run(main(dex_file, image_dir))