Sdílet prostřednictvím


Zkoumání umění napříč kulturou a středními s rychlými, podmíněnými, k-nejbližšími sousedy

Tento článek slouží jako vodítko pro hledání shody prostřednictvím k-nejbližších sousedů. Nastavili jste kód, který umožňuje dotazy zahrnující kultury a média umění maskované z Metropolitního muzea umění v NYC a Rijksmuseum v Amsterdamu.

Požadavky

  • Připojte poznámkový blok k jezeru. Na levé straně vyberte Přidat a přidejte existující jezerní dům nebo vytvořte jezero.

Přehled BallTree

Struktura fungující za modelem KNN je BallTree, což je rekurzivní binární strom, kde každý uzel (neboli "míč") obsahuje oddíl bodů dat, které se mají dotazovat. Sestavení BallTree zahrnuje přiřazení datových bodů k "míči", jehož střed jsou nejblíže (s ohledem na určitou zadanou funkci), což vede ke struktuře, která umožňuje procházení podobné binárnímu stromu a umožňuje najít k-nejbližší sousedy na listu BallTree.

Nastavení

Naimportujte potřebné knihovny Pythonu a připravte datovou sadu.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Naše datová sada pochází z tabulky obsahující informace o uměleckých dělech z muzea Met i Rijks. Schéma je následující:

  • ID: Jedinečný identifikátor pro kus umění
    • ID ukázkového metu: 388395
    • Id ukázky Rijks: SK-A-2344
  • Název: Název umělecké části, jak je napsané v databázi muzea
  • Umělec: umělecký umělec, jak je napsané v databázi muzea
  • Thumbnail_Url: Umístění miniatury obrázku ve formátu JPEG
  • Image_Url Umístění obrázku umělecké části hostované na webu Met/Rijks
  • Kultura: Kategorie kultury, pod kterou umělecká část spadá
    • Ukázkové jazykové kategorie: latinamerická, egyptská atd.
  • Klasifikace: Kategorie média, pod kterou umělecká část spadá
    • Ukázkové střední kategorie: dřevo, obrazy atd.
  • Museum_Page: Odkaz na dílo umění na webu Met/Rijks
  • Norm_Features: Vložení obrázku umělecké části
  • Muzeum: Určuje, ze kterého muzea pochází kus
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definování kategorií, na které se má dotazovat

Používají se dva modely KNN: jeden pro jazykovou verzi a jeden pro střední.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definování a přizpůsobení modelů ConditionalKNN

Vytvoření modelů ConditionalKNN pro sloupce se střední i jazykovou verzí; každý model přebírá výstupní sloupec, obsahuje sloupec (vektor funkcí), sloupec hodnot (hodnoty buněk pod výstupním sloupcem) a sloupec popisku (kvalita, na které je příslušná KNN podmíněná).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definování odpovídajících a vizualizačních metod

Po počátečním nastavení datové sady a kategorie připravte metody, které budou dotazovat a vizualizovat výsledky podmíněné sítě KNN.

addMatches() vytvoří datový rámec s několika shodami pro každou kategorii.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() volání plot_img k vizualizaci nejlepších shod pro každou kategorii do mřížky.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Spojení všech součástí dohromady

Definujte test_all() , že chcete vzít data, modely CKNN, hodnoty ID umění, na které se mají dotazovat, a cestu k souboru, do které se má výstupní vizualizace uložit. Modely střední a jazykové verze byly dříve natrénovány a načteny.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Ukázka

Následující buňka provádí dávkové dotazy s požadovaná ID obrázků a název souboru pro uložení vizualizace.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")