빠르고 조건부인 k와 가장 가까운 이웃을 사용하여 문화와 매체를 아우르는 예술을 탐구하세요
이 문서는 k와 가장 가까운 이웃을 통한 일치 찾기에 대한 안내를 담고 있습니다. 뉴욕 메트로폴리탄 미술관과 암스테르담 국립 미술관에서 수집한 문화 및 예술 매체와 관련된 쿼리를 허용하는 코드를 설정했습니다.
필수 조건
- 레이크하우스에 Notebook을 첨부합니다. 왼쪽에서 추가를 선택하여 기존 레이크하우스를 추가하거나 레이크하우스를 만듭니다.
BallTree 개요
KNN 모델 뒤에서 작동하는 구조는 각 노드(또는 "볼")가 쿼리할 데이터 요소의 파티션을 포함하는 재귀 이진 트리인 BallTree입니다. BallTree를 빌드하려면 (지정된 특정 기능과 관련하여) 중심이 가장 가까운 "볼"에 데이터 요소를 할당하여 이진 나무와 같은 순회를 허용하고 BallTree 리프에서 k-가장 가까운 이웃을 찾는 데 빌려주는 구조가 생성됩니다.
설정
필요한 Python 라이브러리를 가져오고 데이터 세트를 준비합니다.
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
우리의 데이터 세트는 Met 및 Rijks 박물관의 작품 정보가 들어있는 테이블에서 제공됩니다. 스키마는 다음과 같습니다.
- id: 아트의 고유 식별자
- 샘플 Met ID: 388395
- 샘플 Rijks ID: SK-A-2344
- 제목: 박물관 데이터베이스에 기록된 작품의 제목
- 예술가: 박물관의 데이터베이스에 기록된 작품의 아티스트
- Thumbnail_Url: 작품의 JPEG 썸네일 위치
- Image_Url Met/Rijks 웹 사이트에서 호스팅되는 예술 작품 이미지의 위치
- 문화: 예술 작품이 속하는 문화의 범주
- 샘플 문화권 범주: 라틴 아메리카, 이집트 등.
- 분류: 작품의가 속하는 매체의 범주
- 샘플 중간 범주: 목공, 그림 등.
- Museum_Page: Met/Rijks 웹 사이트의 예술 작품에 연결
- Norm_Features: 작품의 이미지 포함
- 박물관: 작품이 유래한 박물관을 지정합니다.
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
쿼리할 범주 정의
두 개의 KNN 모델이 사용됩니다. 하나는 문화권용이고 다른 하나는 매체용입니다.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
ConditionalKNN 모델 정의 및 맞춤
중간 및 문화권 열 모두에 대한 ConditionalKNN 모델을 만듭니다. 각 모델은 출력 열, 기능 열(기능 벡터), 값 열(출력 열 아래의 셀 값) 및 레이블 열(각 KNN이 조건부로 지정된 품질)을 사용합니다.
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
일치 및 시각화 메서드 정의
초기 데이터 세트 및 범주 설정 후에 조건부 KNN의 결과를 쿼리하고 시각화하는 메서드를 준비합니다.
addMatches()
(은)는 범주당 몇 개의 일치 항목이 있는 데이터 프레임을 만듭니다.
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
(은)는 각 범주에 대한 상위 일치 항목을 표로 시각화하기 위해 plot_img
(을)를 호출합니다.
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
모든 항목 요약
데이터, CKNN 모델, 쿼리할 아트 ID 값 및 출력 시각화를 저장할 파일 경로를 가져오도록 test_all()
(을)를 정의합니다. 매체 및 문화 모델은 이전에 학습 및 로드되었습니다.
# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
데모
다음 셀은 원하는 이미지 ID와 시각화를 저장하는 파일 이름을 지정하여 일괄 처리된 쿼리를 수행합니다.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")