次の方法で共有


Python ユーザー定義テーブル関数 (UDTF)

重要

この機能は、Databricks Runtime 14.3 LTS 以降でパブリック プレビュー段階にあります。

ユーザー定義テーブル関数 (UDTF) を使用すると、スカラー値の代わりにテーブルを返す関数を登録できます。 各呼び出しから単一の結果値を返すスカラー関数とは異なり、各 UDTF は SQL ステートメントの FROM 句で呼び出され、テーブル全体を出力として返します。

各 UDTF 呼び出しは、0 個以上の引数を受け取ることができます。 これらの引数は、スカラー式または入力テーブル全体を表すテーブル引数にすることができます。

基本的な UDTF の構文

Apache Spark は、yield を使用して出力行を生成する必須の eval メソッドを持つ Python クラスとして Python UDTF を実装します。

クラスを UDTF として使用するには、PySpark udtf 関数をインポートする必要があります。 Databricks では、この関数をデコレーターとして使用し、returnType オプションを使用してフィールド名と型を明示的に指定することをお勧めしています (後のセクションで説明するようにクラスが analyze メソッドを定義している場合を除く)。

次の UDTF は、2 つの整数引数の固定リストを使用してテーブルを作成します。

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+

UDTF を登録する

UDTF はローカルの SparkSession に登録され、ノートブック レベルまたはジョブ レベルで分離されます。

UDTF を Unity Catalog のオブジェクトとして登録することはできず、UDTF を SQL ウェアハウスで使用することはできません。

関数 spark.udtf.register() を使用して、SQL クエリで使用できるように UDTF を 現在の SparkSession に登録することができます。 SQL 関数と Python UDTF クラスの名前を指定します。

spark.udtf.register("get_sum_diff", GetSumDiff)

登録済みの UDTF を呼び出す

登録すると、%sql マジック コマンドまたは spark.sql() 関数のいずれかを使用して、SQL で UDTF を使用できます。

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);

Apache Arrow を使用する

UDTF が少量のデータを入力として受け取り、大きなテーブルを出力する場合、Databricks では Apache Arrow の使用をお勧めしています。 UDTF を宣言するときに useArrow パラメーターを指定することで、Apache Arrow を有効にすることができます。

@udtf(returnType="c1: int, c2: int", useArrow=True)

可変個数引数リスト - *args と **kwargs

Python の *args 構文または **kwargs 構文を使用して、不特定数の入力値を処理するロジックを実装できます。

次の例では、引数の入力長と型を明示的に確認しながら、同じ結果が返されます。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        assert(len(args) == 2)
        assert(isinstance(arg, int) for arg in args)
        x = args[0]
        y = args[1]
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

次に示すのは同じ例ですが、キーワード引数を使用しています。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()

登録時に静的スキーマを定義する

UDTF は、列名と型の順序付けられたシーケンスで構成される出力スキーマを持つ行を返します。 UDTF スキーマがすべてのクエリに対して常に同じであるべき場合は、@udtf デコレーターの後に静的な固定スキーマを指定することができます。 スキーマは、次のいずれかでなければなりません:StructType

StructType().add("c1", StringType())

または構造体型を表す DDL 文字列

c1: string

関数呼び出し時に動的スキーマを計算する

UDTF は、入力引数の値に応じて、呼び出しごとに出力スキーマをプログラムで計算することもできます。 これを行うには、特定の UDTF 呼び出しに渡される引数に対応する 0 個以上のパラメーターを受け取る analyze という静的メソッドを定義します。

analyze メソッドの各引数は、次のフィールドを含む AnalyzeArgument クラスのインスタンスです。

AnalyzeArgument クラスのフィールド 説明
dataType DataType としての入力引数の型。 入力テーブル引数の場合、これはテーブルの列を表す StructType です。
value Optional[Any] としての入力引数の値。 定数ではないテーブル引数またはリテラル スカラー引数の場合、これは None です。
isTable 入力引数が BooleanType としてのテーブルであるかどうか。
isConstantExpression 入力引数が BooleanType として定数のたたみ込みが可能な式かどうか。

analyze メソッドは、結果テーブルのスキーマを StructType として含む AnalyzeResult クラスのインスタンスのほか、いくつかのオプション フィールドを返します。 UDTF が入力テーブル引数を受け取る場合、AnalyzeResult には、後で説明するように、複数の UDTF 呼び出しにわたって入力テーブルの行をパーティション分割し、並べ替えるための要求された方法を含めることもできます。

AnalyzeResult クラスのフィールド 説明
schema StructType としての結果テーブルのスキーマ。
withSinglePartition すべての入力行を同じ UDTF クラス インスタンスに BooleanType として送信するかどうか。
partitionBy 空以外に設定されている場合、パーティション式の値の一意の組み合わせを持つすべての行が UDTF クラスの個別のインスタンスによって使用されます。
orderBy 空以外に設定されている場合、各パーティション内の行の順序を指定します。
select 空以外に設定されている場合、これは、UDTFが Catalyst に入力 TABLE 引数の列に対して評価するよう指定する一連の式です。 UDTF は、リスト内の名前ごとに 1 つの入力属性をリスト内の表示順に受け取ります。

この analyze の例では、入力文字列引数の単語ごとに 1 つの出力列を返します。

@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
            counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result
['word_0', 'word_1']

将来の eval 呼び出しに状態を転送する

初期化を実行し、結果を同じ UDTF 呼び出しの将来の eval メソッド呼び出しに転送するための便利な場所として analyze メソッドを機能させることができます。

これを行うには、AnalyzeResult のサブクラスを作成し、analyze メソッドからサブクラスのインスタンスを返します。 次に、そのインスタンスを受け取るために __init__ メソッドへの引数を追加します。

この analyze の例では、固定出力スキーマを返しますが、将来の __init__ メソッド呼び出しで使用されるカスタム情報を結果のメタデータに追加します。

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

self.spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-------+-------+
| count | buffer|
+-------+-------+
|    20 |  "abc"|
+-------+-------+

出力行を生成する

eval メソッドは、入力テーブル引数の各行に対して 1 回 (テーブル引数が指定されていない場合は 1 回のみ) 実行され、最後に terminate メソッドが 1 回呼び出されます。 いずれのメソッドも、タプル、リスト、または pyspark.sql.Row オブジェクトを生成することで、結果スキーマに準拠した 0 行以上の行を出力します。

この例では、3 つの要素のタプルを指定して行を返します。

def eval(self, x, y, z):
  yield (x, y, z)

かっこを省略することもできます。

def eval(self, x, y, z):
  yield x, y, z

1 列のみの行を返すには、末尾にコンマを追加します。

def eval(self, x, y, z):
  yield x,

pyspark.sql.Row オブジェクトを生成することもできます。

def eval(self, x, y, z)
  from pyspark.sql.types import Row
  yield Row(x, y, z)

この例では、Python リストを使用して terminate メソッドから出力行を生成します。 この目的のために、UDTF 評価の前の手順で使用したクラスの内部に状態を格納することができます。

def terminate(self):
  yield [self.x, self.y, self.z]

UDTF にスカラー引数を渡す

リテラル値またはそれに基づく関数で構成される定数式としてスカラー引数を UDTF に渡すことができます。 次に例を示します。

SELECT * FROM udtf(42, group => upper("finance_department"));

テーブル引数を UDTF に渡す

Python UDF は、スカラー入力引数に加えて、入力テーブルを引数として受け取ることもできます。 単一の UDTF が 1 つのテーブル引数と複数のスカラー引数を受け取ることもできます。

したがって、TABLE(t) のように、TABLE キーワードに続けて適切なテーブル識別子をかっこで囲むことで、任意の SQL クエリで入力テーブルを渡すことができます。 あるいは、TABLE(SELECT a, b, c FROM t)TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key)) などのテーブルサブクエリを渡すこともできます。

入力テーブル引数は、eval メソッドへの pyspark.sql.Row 引数として表され、入力テーブル内の各行に対して eval メソッドが 1 回呼び出されます。 標準の PySpark 列フィールドの注釈を使用して、各行の列を操作できます。 次の例では、PySpark の Row 型を明示的にインポートし、渡されたテーブルを id フィールドでフィルター処理する方法を示します。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

関数をクエリするには、TABLE SQL キーワードを使用します。

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+

関数呼び出しから入力行のパーティション分割を指定する

テーブル引数を使用して UDTF を呼び出す場合は、任意の SQL クエリが 1 つ以上の入力テーブル列の値に基づいて、複数の UDTF 呼び出しにわたって入力テーブルをパーティション分割することができます。

パーティションを指定するには、関数呼び出しで TABLE 引数の後に PARTITION BY 句を使用します。 これにより、パーティション分割列の値の一意の組み合わせを持つすべての入力行が UDTF クラスの 1 つのインスタンスによってのみ使用されることが保証されます。

単純な列参照に加えて、PARTITION BY 句は入力テーブル列に基づく任意の式も受け入れることに注目してください。 たとえば、文字列の LENGTH を指定したり、日付から月を抽出したり、2 つの値を連結したりできます。

PARTITION BY の代わりに WITH SINGLE PARTITION を指定して、UDTF クラスの 1 つのインスタンスだけがすべての入力行を使用する必要がある 1 つのパーティションのみを要求することもできます。

各パーティション内では、必要に応じて、UDTF の eval メソッドが入力行を使用するときに要求される入力行の順序を指定することができます。 これを行うには、上記の PARTITION BY 句または WITH SINGLE PARTITION 句の後に ORDER BY 句を指定します。

たとえば、次の UDTF があるとします。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

入力テーブルに対して UDTF を呼び出すときに、複数の方法でパーティション分割オプションを指定することができます。

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "def" | 8 |
+-------+----+

analyze メソッドから入力行のパーティション分割を指定する

SQL クエリで UDF を呼び出すときに入力テーブルをパーティション分割する上記の方法のそれぞれに対応する形で、UDTF の analyze メソッドで同じパーティション分割方法を自動的に指定する方法があることに注目してください。

  • UDTF を SELECT * FROM udtf(TABLE(t) PARTITION BY a) として呼び出す代わりに、analyze メソッドを更新してフィールド partitionBy=[PartitioningColumn("a")] を設定し、SELECT * FROM udtf(TABLE(t)) を使用して単に関数を呼び出すことができます。
  • 同様に、SQL クエリで TABLE(t) WITH SINGLE PARTITION ORDER BY b を指定する代わりに、analyze でフィールド withSinglePartition=trueorderBy=[OrderingColumn("b")] を設定した後、単に TABLE(t) を渡すことができます。
  • SQL クエリで TABLE(SELECT a FROM t) を渡す代わりに、analyze を使用して select=[SelectedColumn("a")] を設定した後、単に TABLE(t) を渡すことができます。

次の例で、analyze は、固定出力スキーマを返し、入力テーブルから列のサブセットを選択し、入力テーブルが date 列の値に基づいて複数の UDTF 呼び出しにわたってパーティション分割されるように指定します。

@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the monthly
  values of each `date` column. The rows within each partition will arrive ordered by the `date`
  column. The UDTF will only receive the `date` and `word` columns from the input table.
  """
  from pyspark.sql.functions import (
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
  )

  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add('longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word),
        alias="length_word")])