Utforska konst över kultur och medium med snabba, villkorliga, k-närmaste grannar
Den här artikeln fungerar som en riktlinje för matchning via k-nearest-neighbors. Du ställer in kod som tillåter frågor som involverar kulturer och medier av konst som samlats in från Metropolitan Museum of Art i NYC och Rijksmuseum i Amsterdam.
Förutsättningar
- Bifoga anteckningsboken till ett sjöhus. Till vänster väljer du Lägg till för att lägga till ett befintligt sjöhus eller skapa ett sjöhus.
Översikt över BallTree
Strukturen som fungerar bakom KNN-modellen är en BallTree, som är ett rekursivt binärt träd där varje nod (eller "boll") innehåller en partition av de datapunkter som ska frågas. Att skapa en BallTree innebär att tilldela datapunkter till den "boll" vars mitt de är närmast (med avseende på en viss angiven funktion), vilket resulterar i en struktur som tillåter binärt trädliknande bläddring och lämpar sig för att hitta k-närmaste grannar på ett BallTree-löv.
Ställ in
Importera nödvändiga Python-bibliotek och förbered datauppsättningen.
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
Vår datauppsättning kommer från en tabell som innehåller konstverksinformation från både Met- och Rijks-museerna. Schemat är följande:
- id: En unik identifierare för ett konstverk
- Exempel på Met-ID: 388395
- Exempel på Rijks-ID: SK-A-2344
- Titel: Konststyckets titel, enligt skriven i museets databas
- Konstnär: Konststyckekonstnär, som skrivits i museets databas
- Thumbnail_Url: Platsen för en JPEG-miniatyrbild av konstverket
- Image_Url Plats för en bild av konstverket som finns på Met/Rijks webbplats
- Kultur: Kulturkategori som konstverket faller under
- Exempel på kulturkategorier: latinamerikansk, egyptisk osv.
- Klassificering: Kategori av medium som konstverket faller under
- Exempel på medelstora kategorier: träarbete, målningar osv.
- Museum_Page: Länk till konstverket på Met/Rijks webbplats
- Norm_Features: Inbäddning av konststyckets bild
- Museum: Anger vilket museum verket kommer från
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Definiera kategorier som ska frågas om
Två KNN-modeller används: en för kultur och en för medium.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
Definiera och anpassa ConditionalKNN-modeller
Skapa ConditionalKNN-modeller för både mellan- och kulturkolumnerna. varje modell tar in en utdatakolumn, funktionskolumn (funktionsvektor), värdekolumn (cellvärden under utdatakolumnen) och etikettkolumn (den kvalitet som respektive KNN är villkorad på).
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
Definiera matchnings- och visualiseringsmetoder
Efter den inledande datauppsättningen och kategorikonfigurationen förbereder du metoder som frågar efter och visualiserar det villkorliga KNN-resultatet.
addMatches()
skapar en dataram med en handfull matchningar per kategori.
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
anrop plot_img
för att visualisera toppmatchningar för varje kategori i ett rutnät.
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
Färdigställa allt
Definiera test_all()
för att ta in data, CKNN-modeller, de konst-ID-värden som du vill fråga efter och filsökvägen för att spara utdatavisualiseringen till. Modellerna medium och kultur har tidigare tränats och lästs in.
# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
Demo
Följande cell utför batchbaserade frågor med önskade bild-ID:er och ett filnamn för att spara visualiseringen.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")