Udostępnij za pośrednictwem


Eksplorowanie sztuki w całej kulturze i średniej z szybkimi, warunkowymi, k najbliższymi sąsiadami

Ten artykuł stanowi wytyczne dotyczące znajdowania dopasowań za pośrednictwem k najbliższych sąsiadów. Konfigurujesz kod, który umożliwia wykonywanie zapytań obejmujących kultury i medium sztuki zebrane z Metropolitan Museum of Art w Nowym Jorku i Rijks apk w Amsterdamie.

Wymagania wstępne

  • Dołącz notes do magazynu lakehouse. Po lewej stronie wybierz pozycję Dodaj , aby dodać istniejący obiekt lakehouse lub utworzyć jezioro.

Omówienie obiektu BallTree

Struktura działająca za modelem KNN to BallTree, czyli cykliczne drzewo binarne, w którym każdy węzeł (lub "piłka") zawiera partycję punktów danych do odpytowania. Tworzenie obiektu BallTree polega na przypisaniu punktów danych do "piłki", której środek znajduje się najbliżej (w odniesieniu do określonej funkcji), co powoduje utworzenie struktury, która umożliwia przechodzenie podobne do drzewa binarnego i nadaje się do znalezienia najbliższych sąsiadów w liściu BallTree.

Konfiguracja

Zaimportuj niezbędne biblioteki języka Python i przygotuj zestaw danych.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Nasz zestaw danych pochodzi z tabeli zawierającej informacje o sztuce zarówno z muzeów Met, jak i Rijks. Schemat jest następujący:

  • id: unikatowy identyfikator dzieła sztuki
    • Przykładowy identyfikator met: 388395
    • Przykładowy identyfikator Rijks: SK-A-2344
  • Tytuł: Tytuł sztuki, napisany w bazie danych muzeum
  • Artysta: Artysta sztuki, napisany w bazie danych muzeum
  • Thumbnail_Url: Lokalizacja miniatury JPEG dzieła sztuki
  • Image_Url Lokalizacja obrazu dzieła sztuki hostowanego na stronie internetowej Met/Rijks
  • Kultura: Kategoria kultury, pod którą znajduje się kawałek sztuki
    • Kategorie kultury przykładowej: latynoamerykańska, egipska itp.
  • Klasyfikacja: Kategoria średniej, pod którą znajduje się kawałek sztuki
    • Przykładowe kategorie średnie: stolarka, obrazy itp.
  • Museum_Page: Link do dzieła sztuki na stronie internetowej Met/Rijks
  • Norm_Features: Osadzanie obrazu utworu sztuki
  • Muzeum: Określa, z którego muzeum pochodzi kawałek
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definiowanie kategorii do odpytowania

Używane są dwa modele KNN: jeden dla kultury i jeden dla średniej.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definiowanie i dopasowywanie modeli ConditionalKNN

Tworzenie modeli ConditionalKNN dla kolumn średniego i kulturowego; każdy model przyjmuje kolumnę wyjściową, funkcje kolumny (wektor funkcji), kolumnę wartości (wartości komórek w kolumnie wyjściowej) i kolumnę etykiety (jakość, na której jest spełniony odpowiednia nazwa KNN).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definiowanie pasujących i wizualizowania metod

Po wstępnej konfiguracji zestawu danych i kategorii przygotuj metody, które będą wykonywać zapytania i wizualizować wyniki warunkowej nazwy KNN.

addMatches() tworzy ramkę danych z kilkoma dopasowaniami na kategorię.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() wywołania plot_img w celu wizualizacji pierwszych dopasowań dla każdej kategorii do siatki.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Zebranie wszystkich elementów

Zdefiniuj test_all() , aby pobrać dane, modele CKNN, wartości identyfikatora sztuki do wykonania zapytania oraz ścieżkę pliku w celu zapisania wizualizacji wyjściowej. Modele średnie i kulturowe zostały wcześniej wytrenowane i załadowane.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Wersja demonstracyjna

Poniższa komórka wykonuje zapytania wsadowe z identyfikatorami żądanych obrazów i nazwą pliku w celu zapisania wizualizacji.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")