首頁/ 汽車/ 正文

如何在OneFlow中新增運算元

撰文|姚遲、鄭澤康

本文將以開發一個 leaky_relu(準確說是 leaky_relu_yzh op,因為 master 分支的 leaky_relu 組合了其它知識點)為例介紹如何在 OneFlow 中新增運算元(

https://github。com/Oneflow-Inc/oneflow/pull/8350

)。

1

背景

op 與 kernel

op 與 kernel 是兩個有關聯的概念。op 是邏輯上的運算元,包含 OneFlow Compiler 在構建計算圖時所需要的必要資訊,如輸入、輸出形狀,哪些張量需要自動求導等資訊。有了 op 中的資訊,OneFlow Compiler 就可以構建計算圖並依據計算圖做資源申請、構建等操作(如根據張量的輸入輸出大小申請記憶體), 但是 op 中不包含具體的處理資料的邏輯。

在真正需要處理資料時,OneFlow Runtime 會啟動 kernel 完成計算,所以 kernel 中包含了具體處理資料的邏輯。對於一個邏輯上的 op,OneFlow Runtime 會根據資料型別、硬體裝置(比如是 CPU 還是 CUDA)的具體情況,選擇啟動不同的 kernel。

OneFlow 中的系統 op 與 user op

在 OneFlow 系統中存在兩類運算元(op):系統 op 和 user op。

系統 op 定義在:oneflow/core/operator/ 目錄, 對應的 kernel 實現在:oneflow/core/kernel 目錄。系統 op 是對構圖、流水等系統性能較為關鍵的一些 op。

除極少數 op 屬於系統 op 外,大多數 op 都是 user op,這些 user op 和使用者模型業務邏輯相關。OneFlow user op 的定義及 kernel 實現分別在 oneflow/user/ops 和 oneflow/user/kernels 目錄下。

目前 OneFlow 已實現了豐富的運算元庫,但是當已有的運算元庫無法滿足搭建模型的需求時,就需要新增運算元。本文介紹的新增運算元指的是新增 user op。

ODS 與 TableGen

Table

Gen(

https://llvm。org/docs/TableGen/index。html

) 是一個程式碼生成工具,簡單而言,它讀取並解析一個

。td

格式(語法接近 C++ 模板)的檔案,然後交給 TableGen 後端

https://llvm。org/docs/TableGen/BackEnds。html

生成另外格式的語言。

MLIR 基於 TableGen 制定了一套運算元定義規範ODS

https://mlir。llvm。org/docs/OpDefinitions/

以及對應的後端 OpDefinitionsGen

https://github。com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpDefinitionsGen。cpp。

OneFlow 在 ODS 的基礎上,實現了 TableGen OneFlow 後端

https://github。com/Oneflow-Inc/oneflow/tree/master/tools/oneflow-tblgen

,並使用它來定義 OneFlow user op。

因此,OneFlow 的 user op 定義寫在 OneFlowUserOps。td 檔案中。

2

開發 op

在 OneFlow 中開發一個新的 user op,主要分為以下4步:

定義 op

實現 kernel 計算邏輯

匯出 functional 介面

實現用於求導的反向邏輯

定義 op

定義 op 指的是,對 op 的名稱,op 的輸入、輸出資料型別和 op 的屬性進行宣告。OneFlow 遵循 MLIR 的 ODS(Operation Definition Specification)

https://mlir。llvm。org/docs/OpDefinitions/

實現了自己的 MLIR OneFlow Dialect。在運算元定義方面,這樣做的好處是,各種推導函式和序列化/反序列化的介面都可以委託給 ODS,降低了人工手寫出錯的機率,後續最佳化、格式轉化等流程可以更靈活。

定義一個 OneFlow user op,主要包括 5 個部分,分別是:

op class

輸入 input

輸出 output

屬性 attrs

匯出並實現推導介面

op class

可以在

oneflow/ir/include/OneFlow/OneFlowUserOps。td

檢視 op 定義的原始碼。

def

關鍵字開頭定義一個 op,該 op 繼承

OneFlow_BaseOp

,同時指定

OneFlow_BaseOp

的模版引數。模版引數依次為 op type name、Trait (

https://mlir。llvm。org/docs/Traits/

)列表。

def OneFlow_LeakyReluYZHOp : OneFlow_BaseOp<“leaky_relu_yzh”, [NoSideEffect, DeclareOpInterfaceMethods]> {//。。。}

其中 “

leaky_relu_yzh

” 是指定的 op type name。每個 op 都需要指定一個全域性唯一的 op type name 作為全域性識別符號。

第二個模板引數是一個 list(

[。。。]

),其中的每一項都是一個 Trait,OneFlow 中常用的有:

NoSideEffect 表示該運算元無副作用(即不會改變記憶體、網路、管道、磁碟等的系統狀態),這個特性可以指導某些最佳化操作

NoGrad 表示該運算元在數學上沒有梯度(不可導)

CpuOnly 表示該運算元只支援在 CPU 裝置上執行

SupportNonContiguous 表示該運算元是否支援 NonContiguous 張量(關於 Contiguous Tensor 的概念,可以參考 PyTorch Internals 中的相關內容 )

輸入 input 與輸出 output

透過重寫

input

域來定義 op 的輸入,比如

// 一個輸入 xlet input = (ins OneFlow_Tensor:$x);

定義了一個輸入張量

x

。輸入的格式為 輸入型別:

$name

輸入型別目前包括:

OneFlow_Tensor

Variadic

:指可變 tensor,比如 concat op,支援 concat 可變個數的 tensor。

Optional

:表示這個 tensor 是可選的,既可以有也可以沒有,比如 conv op 中的 add_output。

一個 op 也可以定義多個輸入,比如:

// 兩個輸入:a, b let input = (ins OneFlow_Tensor:$a, OneFlow_Tensor:$b );

透過重寫

output

域來定義 op 的輸出,比如下面定義了 2 個輸出張量:

let output = (outs OneFlow_Tensor:$out0, OneFlow_Tensor:$out1);

屬性 attrs

透過重寫

attrs

域定義 op 的屬性,比如定義 dropout (

https://oneflow。readthedocs。io/en/master/functional。html#oneflow。nn。functional。dropout

)中的

rate

屬性:

let attrs = (ins DefaultValuedAttr:$rate );

它表示名為

$rate

的型別是

F32Attr

,預設值是

0。

。這裡也可以不指定預設值:

let attrs = (ins F32Attr:$rate );

I32Attr、F32Attr、BoolAttr、StrAttr、I32ArrayAttr 等常見基礎資料型別定義在 OpBase。td

https://github。com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/OpBase。td#L1077-L1086

)中。

OneFlow 自定義資料型別,如 ShapeAttr、DTArrayAttr 等定義在 OneFlowBase。td

https://github。com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/OneFlowBase。td#L27-L35

)中。

匯出並實現推導介面

還有一些其它域,用於指定是否生成對應的介面。這些介面往往是構建計算圖過程中的推導介面。

比如 shape 推導(根據輸入的 shape 推導輸出的推導)、data type 推導、SBP 推導等。

OneFlow-TableGen 僅負責生成這些函式的介面,開發者需要在其自動生成的 cpp 檔案中實現這些介面。預設情況不會生成下列任何介面,開發者需要顯式指定需要生成哪些介面。

let has_check_fn = 1; // 生成屬性檢查介面 let has_logical_tensor_desc_infer_fn = 1; // 生成 logical shape 推導介面 let has_physical_tensor_desc_infer_fn = 1; // 生成 physical shape 推導介面 let has_get_sbp_fn = 1; // 生成 get sbp 介面 let has_sbp_signature_infer_fn = 1; // 生成 sbp signature 推導介面,未來會移除,推薦使用 has_nd_sbp_infer_fn let has_data_type_infer_fn = 1; // 生成 data type 推導介面 let has_device_and_stream_infer_fn = 1; // 生成 device 推導介面 let has_input_arg_modify_fn = 1; // 生成輸入 modify 介面,比如設定 is_mutable、requires_grad(用於Lazy)等 let has_output_arg_modify_fn = 1; // 生成輸出 modify 介面,比如設定 is_mutable、requires_grad(用於Lazy)等 let has_output_blob_time_shape_infer_fn = 1; // 生成輸出 time shape 推導介面 let has_nd_sbp_infer_fn = 1; // 生成 nd sbp 推導介面

一般常用的是下面幾個:

let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; let has_data_type_infer_fn = 1; let has_get_sbp_fn = 1;

瞭解完上面這些概念和用法後,可以開始修改

oneflow/ir/include/OneFlow/OneFlowUserOps。td

檔案。

leaky_relu_yzh op 完整的定義見 這裡

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/ir/include/OneFlow/OneFlowUserOps。td#L8418-L8433

OneFlowUserOps。td

中新增Op定義之後,重新 make 後會自動在 build 目錄下的

oneflow/core/framework/

目錄下生成檔案以下幾個檔案:

op_generated。h

:由解析

。td

檔案生成的 op C++ 類

op_generated。cpp

:由解析

。td

檔案生成的 op 註冊程式碼(包含呼叫 REGISTER_USER_OP 宏的程式碼)

之後需要做的就是在 oneflow/user/ops (

https://github。com/Oneflow-Inc/oneflow/tree/master/oneflow/user/ops

)目錄下新加一個 cpp 檔案,用於實現 op 的介面。

比如 leaky_relu_yzh 對應的文在 oneflow/user/ops/leaky_relu_yzh_op。cpp

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/user/ops/leaky_relu_yzh_op。cpp#L21-L79

),

實現了推導邏輯張量、推導物理張量、推導 SBP 資訊以及推導輸出資料型別各介面。

實現 Kernel 邏輯

op 的計算支援多種裝置(如 CPU、GPU、DCU 等),所以要分別實現計算邏輯。

相關程式碼:

Leaky ReLU CPU Kernel

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/user/kernels/leaky_relu_yzh_kernel。cpp

Leaky ReLU GPU KernelCPU

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/user/kernels/leaky_relu_yzh_kernel。cu

計算邏輯

templateclass CpuLeakyReluYZHKernel final : public user_op::OpKernel { public: CpuLeakyReluYZHKernel() = default; ~CpuLeakyReluYZHKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(“x”, 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(“y”, 0); const int32_t elem_cnt = x->shape()。elem_cnt(); const T* x_ptr = x->dptr(); T* y_ptr = y->mut_dptr(); const auto alpha = ctx->Attr(“alpha”); FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha * x_ptr[i]; } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }};

在 OneFlow 中實現 kernel, 必須定義一個繼承自

oneflow::user_op::OpKernel

的類,並重寫其中的虛擬函式。在以上程式碼中,重寫了

Compute

AlwaysComputeWhenAllOutputsEmpty

兩個虛擬函式,它們的意義分別是:

Compute

必須重寫,在其中實現具體的運算邏輯

AlwaysComputeWhenAllOutputsEmpty

必須重寫,對於絕大多數 op 而言,直接返回 false 即可。對於極少數內部需要維護狀態,即使輸出為空也需要呼叫 kernel 進行計算的 op 而言,應該返回 true

Compute

方法中透過呼叫

user_op::KernelComputeContext* ctx

中的介面,可以獲取輸入張量、輸出張量、attr 具體的資料,再按照運算元的演算法邏輯對它們進行處理。以下是對

CpuLeakyReluKernel::Compute

處理邏輯的解讀:

首先取得 “

x

”,“

y

” 兩個 Tensor。傳入

Tensor4ArgNameAndIndex

的字串要和之前在

OneFlowUserOps。td

設定的名稱一致

獲取

x

的元素個數,以便後續用於

for

迴圈進行計算

獲取屬性

alpha

進入次數為

elem_cnt

for

迴圈,將結果寫入

註冊 Kernel

實現 kernel 類後,需要呼叫

REGISTER_USER_KERNEL

註冊。

#define REGISTER_CPU_LEAKY_RELU_YZH_KERNEL(dtype) \ REGISTER_USER_KERNEL(“leaky_relu_yzh”) \ 。SetCreateFn>() \ 。SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType(“y”, 0) == GetDataType::value));

這裡會呼叫REGISTER_USER_KERNEL宏,包括以下資訊:

op type name:為哪個 op 註冊 kernel

SetCreateFn()

:該模板方法的模板引數

T

,就是我們實現的 kernel 類,OneFlow Runtime 將使用它建立 kernel 物件。

SetIsMatchedHob

:因為一個 op 可能有多個 kernel,要想根據物理裝置及資料格式的不同而選擇不同的 kernel 進行計算,就需要呼叫 SetIsMatchedHob 進行設定。該方法接受一個表示式,表示式為 true 時,OneFlow 將呼叫該 kernel 完成計算。以上程式碼的匹配邏輯是:當硬體裝置為

cpu

,且

y

的資料型別和

dtype

一致時,選擇呼叫註冊的 kernel 類(

CpuLeakyReluYZHKernel

)。

GPU 計算邏輯

CUDA 程式設計基礎知識入門可以參考:

影片:CUDA 的由來(

https://www。bilibili。com/video/BV1Mb4y1p7BG

影片:CUDA 的入門小程式(

https://www。bilibili。com/video/BV1bF411s76k

影片:執行緒層級(

https://www。bilibili。com/video/BV1MZ4y127Sq

不過以上的影片都無法替代自己認真學習官方資料:CUDA C Programming Guide(

https://docs。nvidia。com/cuda/cuda-c-programming-guide/index。html

瞭解了 CUDA 的基礎知識,就不難理解 leaky_relu CUDA 版本的實現。

首先定義了 leaky_relu 前向運算的 CUDA 核函式

template__global__ void LeakyReluForwardGpu(const int n, const float alpha, const T* x, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] > 0 ? x[i] : x[i] * alpha; }}

其中呼叫了宏 CUDA_1D_KERNEL_LOOP (

https://github。com/Oneflow-Inc/oneflow/blob/master/oneflow/core/device/cuda_util。h#L91-L94

)進行運算

在 Compute 函式中,呼叫了

RUN_CUDA_KERNEL

(也是定義在

cuda_util。h

這個檔案中)這個宏啟動核函式。

對應的 GPU kernel 類的實現見:

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/user/kernels/leaky_relu_yzh_kernel。cu#L32-L49

其中用到了啟動 kernel 的宏

RUN_CUDA_KERNEL

,它的定義是:

#define RUN_CUDA_KERNEL(func, device_ctx_ptr, thread_num, 。。。) \ func<<cuda_stream()>>>(__VA_ARGS__)

第一個引數是核函式名字

第二個引數是 device context,後續獲取對應的 cuda_stream

第三個引數是要啟動的執行緒數量,會根據執行緒數量來計算所需的 Block 數目。

因為 leaky relu 是 elementwise 運算,各個元素互不影響,所以我們啟動了 elem_cnt 個執行緒。

後續的註冊與 CPU 版本類似,這裡不再贅述。直接參考以下程式碼即可:

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/user/kernels/leaky_relu_yzh_kernel。cu#L51-L62

可以看到不同裝置類的 Compute 中大部分程式碼是重複的。一種更優的程式碼組織方式是用一個

。cpp

檔案完成 kernel 和註冊的邏輯,

。cu

檔案編寫 GPU Kernel 函式和 GPU 模板特化的程式碼,

。h

檔案用於定義和編寫註冊宏。可參考 dim_gather_kernel_*

https://github。com/Oneflow-Inc/oneflow/tree/master/oneflow/user/kernels

)中的程式碼。

OneFlow 為了適配多種裝置,還提供了 Primitive 元件,可以參考:Primitive PR

https://github。com/Oneflow-Inc/oneflow/pull/6234

匯出 functional 介面

關於 functional 介面層的詳細介紹在這裡:

https://github。com/Oneflow-Inc/oneflow/wiki/Functional-Interface

概括而言,functional 層起到了“上接 Python,下聯 C++”的作用:

┌─────────────┐ │ Module │ │ (Python) │ ├─────────────┤ │ │ │ Functional │ ├─────────────┤ │ │ │ Op/Kernels │ │ (C++) │ └─────────────┘

因此,在上文定義 op 和註冊 kernel 後,需要為運算元匯出 functional 介面,才能使使用者透過 Python 程式碼呼叫該運算元。

匯出 functional 介面分為以下幾個步驟:

實現對應的 functor 並註冊

在 oneflow/core/functional/functional_api。yaml 中新增介面描述

實現對應的 functor 並註冊

對於 leaky_relu_yzh op,在 activation_functor。cpp

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/core/functional/impl/activation_functor。cpp#L391-L421

) 中,對其進行定義:

class LeakyReluYZHFunctor { public: LeakyReluYZHFunctor() { op_ = CHECK_JUST(one::OpBuilder(“leaky_relu_yzh”)。Input(“x”)。Output(“y”)。Build()); } Maybe operator()(const std::shared_ptr& x, const float& alpha) const { MutableAttrMap attrs; JUST(attrs。SetAttr(“alpha”, alpha)); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } private: std::shared_ptr op_;};

在構造函數里,構造了

leaky_relu

這個op

實現

operator()

過載運算子,透過

Dispatch

呼叫構造好的 op,並分別傳入輸入,屬性

類似的我們也給 LeakyReluGrad 匯出 functional 介面,以便後續編寫求導邏輯使用。

最後我們需要註冊到 Functional Library:

https://github。com/Oneflow-Inc/oneflow/blob/7ab4b0f08c86a6f8af08b44daa510725942288fb/oneflow/core/functional/impl/activation_functor。cpp#L610-L611

m。add_functor(“LeakyReluYZH”); // 注意最後字串中的名字在後續的 functional_api。yaml 中會用到

透過

m。add_functor

註冊後的 functor,可以在 C++ 層使用,如透過

functional::LeakyRelu

就可以呼叫

LeakyReluFunctor

在 functional_api.yaml 中新增介面描述

functional 透過解析 yaml 配置檔案,在 build 過程中自動幫我們生成介面。

在functional_api。yaml

https://github。com/Oneflow-Inc/oneflow/blob/master/oneflow/core/functional/functional_api。yaml

)檔案中,編寫相關配置。

https://github。com/Oneflow-Inc/oneflow/pull/8350/files#diff-4b35c1dcdbae81b75439ba570bc149554ca85b83757430613fcb612ae25972afR572-R579

- name: “leaky_relu_yzh” signature: “Tensor (Tensor x, Float alpha) => LeakyReluYZH” bind_python: True

其中

name

表示匯出到 Python 介面後函式的名字,比如匯出後在 Python 下使用就是

flow。_C。leaky_relu_yzh(。。。)

signature

用於描述介面原型及 C++ 程式碼的對應關係。

=>

左邊的為原型;

=>

右邊為對應的 Functional Library 中的名字。這裡

LeakyRelu

和前面匯出時指定的字串是一致的。

bind_python

,表示這個介面是否需要繫結 Python 介面 。比如下面的

leaky_relu_grad

,我們不會在 Python 層用到(但會在 C++ 層求導使用),所以設定為 False。

完成以上工作後,新增的運算元已經支援正向運算,編譯好程式碼便可以進行如下簡單的測試:

import oneflow as flow import numpy as npx_tensor = flow。Tensor(np。random。randn(3, 3))out = flow。_C。leaky_relu_yzh(x_tensor, alpha=0。2)

但是,還需要註冊反向,才能支援反向傳播。我們也先將反向需要的

LeakyReluGrad

匯出為 functional 介面。

- name: “leaky_relu_yzh_grad” signature: “Tensor (Tensor x, Tensor dy, Float alpha) => LeakyReluYZHGrad” bind_python: False實現用於求導的反向邏輯

反向傳播的本質就是高數中的鏈式法則,只不過 Autodiff 將鏈式法則變得模組化、易複用。

可以先閱讀 CSC321 Lecture 10: Automatic Differentiation(

https://www。cs。toronto。edu/~rgrosse/courses/csc321_2018/slides/lec10。pdf

瞭解 autodiff 的基本概念。

從邏輯上而言,一個運算元在反向過程中能夠求導數,一般需要以下資訊:

正向過程中的輸入、輸出

正向過程的 attr

反向過程中上一層(正向過程中的下一層)傳遞過來的正向輸出的梯度

未來 Graph 模式和 Eager 模式下的反向邏輯會合並,但目前還是需要分別註冊。

為 Eager 模式註冊反向

求導部分在 oneflow/core/autograd/gradient_funcs/activation。cpp

https://github。com/Oneflow-Inc/oneflow/pull/8350/files#diff-36aeebf7fd5a8b88bd5af87774e7b0b4f76fed42cfb75044d6eebdfe65628347R213-R253

)完成

主要有以下幾部分:

LeakyReluYZHCaptureState :用於儲存資料的結構體

這是一個簡單的結構體,繼承自

AutoGradCaptureState

,用於儲存 op 的屬性,以便於後續求導。

struct LeakyReluYZHCaptureState : public AutoGradCaptureState { bool requires_grad; // 輸入x是否需要梯度 float alpha=0。0; // 輸入的引數alpha};

LeakyReluYZH 類:繼承自 OpExprGradFunction 的類。需要重寫三個函式:

Init

Capture

Apply

class LeakyReluYZH : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { //。。。 } Maybe Capture(LeakyReluYZHCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { //。。。 } Maybe Apply(const LeakyReluYZHCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { //。。。 }};

Init:做的是一些初始化的工作,可以根據前向 op 的配置,來初始化屬性。

Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); }

Capture:用於捕捉前向的 Tensor,屬性,用於後續的求導。

以 LeakyReluYZH 為例子,我們需要得到:a) 輸入的 Tensor,當 Tensor 數值大於 0,梯度為 1,當小於 0,梯度為 alpha b) alpha的數值

Maybe Capture(LeakyReluYZHCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs。size(), 1); // 判斷輸入個數是否為1 ctx->requires_grad = inputs。at(0)->requires_grad(); // 判斷輸入是否需要梯度 if (!ctx->requires_grad) { return Maybe::Ok(); } // 如果不需要梯度,也就不需要求導了,直接返回 Maybe::Ok() ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->alpha = JUST(composed_attrs。GetAttr(“alpha”)); // 獲取 alpha,並存入 ctx->alpha 中 ctx->SaveTensorForBackward(inputs。at(0)); // 呼叫 SaveTensorForBackward 方法,儲存輸入的 Tensor return Maybe::Ok(); }

Apply:實際計算梯度的函式,我們可以拿到先前的 Tensor,並呼叫 functional 介面下注冊的 LeakyReluGrad,求得梯度,並返回

Maybe Apply(const LeakyReluYZHCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads。size(), 1); // 檢查梯度 Tensor 個數是否為 1 in_grads->resize(1); // 因為輸入只有一個,所以我們 resize(1) if (ctx->requires_grad) { const auto& x = ctx->SavedTensors()。at(0); // 呼叫 SavedTensors 介面,拿到先前透過 SaveTensorForBackward 介面儲存的 Tensor in_grads->at(0) = JUST(functional::LeakyReluYZHGrad(x, out_grads。at(0), ctx->alpha)); // 拿到 x,dy,alpha,傳入給 LeakyReluYZHGrad 計算,並將梯度返回給 in_grads->at(0) } return Maybe::Ok(); }

最後一步是註冊,第一個引數是 op type name,第二個引數是繼承自

OpExprGradFunction

的類。

REGISTER_OP_EXPR_GRAD_FUNCTION(“leaky_relu_yzh”, LeakyReluYZH); // 第二個引數是用於求導的類名

為 Graph 模式註冊反向

為 Graph 模式註冊 leaky_relu_yzh op 的反向程式碼在:

https://github。com/Oneflow-Inc/oneflow/pull/8350/files#diff-ef94ddb8f5c25689f2c6fab7a9675f16c95a22018a8c01647b4398ce2fb85a8bR81-R97

REGISTER_USER_OP_GRAD(“leaky_relu_yzh”) 。SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { // 根據前向的 op type name,拼湊出一個 leaky_relu_yzh_grad_op_name (leaky_relu_yzh_grad) const std::string leaky_relu_yzh_grad_op_name = ctx->FwOp()。op_name() + “_grad”; ctx->DefineOp(leaky_relu_yzh_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { // 構建一個 op(op type name 為 leaky_relu_yzh_grad) // 把前向輸出 y 的梯度,作為 leaky_relu_yzh_grad 的輸入 dy // 把前向的 x 作為 leaky_relu_yzh_grad 的輸入 x // 輸出為 dx // attr alpha 同前向一樣 return builder。OpTypeName(“leaky_relu_yzh_grad”) 。InputBind(“dy”, ctx->FwOp()。output_grad(“y”, 0)) 。InputBind(“x”, ctx->FwOp()。input(“x”, 0)) 。Attr(“alpha”, ctx->FwOp()。attr(“alpha”)) 。Output(“dx”) 。Build(); }); // 把 leaky_relu_yzh_grad_op_name 運算元的輸出 dx 的結果 // 繫結到前向輸入 x 的反向梯度上 // 即: // leaky_relu_yzh 的輸入 x 的梯度 = leaky_relu_yzh_grad 的輸出 dx ctx->FwOp()。InputGradBind(user_op::OpArg(“x”, 0), [&ctx, &leaky_relu_yzh_grad_op_name]() -> const std::string& { return ctx->GetOp(leaky_relu_yzh_grad_op_name)。output(“dx”, 0); }); return Maybe::Ok(); });

3

測試與文件

本文覆蓋的內容完成後,只是“跑通”運算元,還需要進一步完善,包括為運算元新增測試和 API 文件,這些將在後續的文章中介紹。

歡迎下載體驗 OneFlow v0.8.0 最新版本:

https://github。com/Oneflow-Inc/oneflow/

相關文章

頂部