Freigeben über


Kultur- und medienübergreifende Recherche von Kunstwerken mit einer schnellen, bedingten Suche mit K-Nearest-Neighbors

Dieser Artikel dient als Leitlinie für die Suche von Übereinstimmungen über K-Nearest-Neighbors. Sie haben Code eingerichtet, der Abfragen von Kunstwerken über mehrere Kulturen und Medien hinweg ermöglicht, die sich im Metropolitan Museum of Art in NYC und dem Rijksmuseum in Amsterdam befindet.

Voraussetzungen

  • Fügen Sie Ihr Notebook an ein Lakehouse an. Wählen Sie auf der linken Seite Hinzufügen aus, um ein vorhandenes Lakehouse hinzuzufügen oder ein Lakehouse zu erstellen.

Übersicht über den BallTree

Die maßgebliche Struktur hinter dem KNN-Modell ist ein BallTree. Dabei handelt es sich um eine rekursive binäre Struktur, in der jeder Knoten (oder „Ball“) eine Partition der abzufragenden Datenpunkte enthält. Das Erstellen eines BallTrees umfasst das Zuweisen von Datenpunkten zu dem „Ball“, dessen Zentrum am nächsten liegt (in Bezug auf ein bestimmtes angegebenes Feature). Dadurch entsteht eine Struktur, die eine ähnliche Durchquerung wie bei einem Binärbaum und so eine K-Nearest-Neighbor-Suche an einem BallTree-Blatt ermöglicht.

Setup

Importieren Sie die erforderlichen Python-Bibliotheken, und bereiten Sie das Dataset vor.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Unser Dataset stammt aus einer Tabelle mit Informationen zu Kunstwerken aus dem Met- und dem Rijks-Museum. Das Schema sieht folgendermaßen aus:

  • id: Ein eindeutiger Bezeichner für ein Kunstwerk
    • Beispiel-Met-ID: 388395
    • Beispiel-Rijks-ID: SK-A-2344
  • Title: Titel des Kunstwerks (Schreibweise wie in der Datenbank des Museums)
  • Artist: Erschaffer des Kunstwerks (Schreibweise wie in der Datenbank des Museums)
  • Thumbnail_Url: Speicherort einer JPEG-Miniaturansicht des Kunstwerks
  • Image_Url: Speicherort eines Bilds des Kunstwerks auf der Met/Rijks-Website
  • Culture: Kategorie der Kultur, unter die das Kunstwerk fällt
    • Beispiele für Kulturkategorien: latin american, egyptian usw.
  • Klassifizierung: Kategorie des Mediums, unter das das Kunstwerk fällt
    • Beispiele für Medienkategorien: woodwork, paintings usw.
  • Museum_Page: Link zum Kunstwerk auf der Met/Rijks-Website
  • Norm_Features: Einbettung des Kunstwerkbildes
  • Museum: Angabe, aus welchem Museum das Stück stammt
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definieren der Kategorien, für die abgefragt werden sollen

Es werden zwei KNN-Modelle verwendet: eines für die Kultur und eines für das Medium.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definieren und Anpassen von ConditionalKNN-Modellen

Erstellen Sie ConditionalKNN-Modelle für die Medien- und Kulturspalte. Jedes Modell übernimmt eine Ausgabespalte, eine Featurespalte (Featurevektor), eine Wertespalte (Zellenwerte unter der Ausgabespalte) und eine Beschriftungsspalte (die Qualitätsbedingung für die jeweilige KNN).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definieren von Abgleichs- und Visualisierungsmethoden

Bereiten Sie nach dem anfänglichen Einrichtung von Dataset und Kategorie die Methoden vor, die die ConditionalKNN-Ergebnisse abfragen und visualisieren.

addMatches() erstellt einen Dataframe mit einigen wenigen Übereinstimmungen pro Kategorie.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() ruft plot_img auf, um die besten Übereinstimmungen für jede Kategorie in einem Raster zu visualisieren.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Alles zusammenfassen

Definieren Sie test_all(), um die Daten, CKNN-Modelle, die abzufragenden Kunst-ID-Werte und den Dateipfad zum Speicherort der Ausgabevisualisierung zu übernehmen. Die Medien- und Kulturmodelle wurden zuvor trainiert und geladen.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Die folgende Zelle führt Batchabfragen unter Angabe gewünschter Bild-IDs und eines Dateinamens zum Speichern der Visualisierung aus.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")