load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:expect.bzl", "expect") load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") # @lint-ignore BUCKRESTRICTEDSYNTAX IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build USED_PT_BACKENDS = [ "CPU", "QuantizedCPU", "SparseCPU", # brings ~20 kb size regression ] def pt_operator_library( name, ops = [], exported_deps = [], check_decl = True, train = False, model = None, include_all_operators = False, include_base_operators = True, **kwargs): (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information( name, model, ) ops = [op.strip() for op in ops] # If ops are specified, then we are in static selective build mode, so we append # base ops to this list to avoid additional special case logic in subsequent code, # unless include_base_operators is explicitly set to False (the default is True) if len(ops) > 0 and include_base_operators: ops.extend(PT_BASE_OPS) labels = kwargs.pop("labels", []) visibility = kwargs.pop("visibility", ["PUBLIC"]) # Sanity check the model name and versions. While the input to both is an array, the # codegen script only ever outputs a single item in the array so we can just assume that # here. If you ever need to depends on more than one assets, just break it up into a separate # BUCK targets. if model_assets or model_versions: if len(model_assets) != 1: fail("Model assets must be of size 1") if len(model_versions) != 1: fail("Model versions must be of size 1") # Is this a traced operator therefore has a YAML file with ops? yaml_option = "" if model_assets and len(model_assets) > 0: # We know these lists are only of length 1 via earlier assert. model_asset = model_assets[0] model_version = model_versions[0] # Pass the YAML file from this asset to the genrule below. yaml_dep = "{}_v{}_yaml".format(model_asset, model_version) fb_native.filegroup( name = yaml_dep, srcs = [ model_asset + ".yaml", ], # The visibility is not set to PUBLIC as this an internal detail. If you see this error # in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that # with parameters not supported outside of codegen targets! ) # Since all selective traced ops are created by automation, we can assume they # have a YAML file at this very location. If it doesn't exist, it means the targets # was hand-crafted which is not a support workflow for traced ops. yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset) fb_xplat_genrule( name = name, out = "model_operators.yaml", cmd = ( "$(exe {exe}) " + "{optionally_root_ops} " + "{optionally_training_root_ops} " + "--rule_name {rule_name} " + "--output_path \"${{OUT}}\" " + "--model_name {model_name} " + "--dep_graph_yaml_path {dep_graph_yaml} " + "{optionally_model_yamls} " + "{optionally_model_versions} " + "{optionally_model_assets} " + "{optionally_model_traced_backends} " + "{optionally_include_all_operators}" ).format( exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml", rule_name = name, model_name = model_name, dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option, optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", ), labels = labels + [ "pt_operator_library", "supermodule:android/default/pytorch", "supermodule:ios/default/public.pytorch", ] + (["pt_train_operator_library"] if train else []), visibility = visibility, **kwargs ) def validate_and_extract_model_information(name, model): model_name = name model_versions = None model_assets = None model_traced_backends = None if model != None: model_name = model.get("name") expect(model_name != None, "Expected Model Name to be present") model_versions = model.get("versions") expect(is_list(model_versions), "Expected model versions to be a list of string") for ver in model_versions or []: expect(is_string(ver), "Expected version '{}' to be string".format(str(ver))) model_assets = model.get("assets") expect( model_assets == None or is_list(model_assets), "Expected model assets to be a list of string if specified", ) for asset_name in model_assets or []: expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name))) model_traced_backends = model.get("traced_backends") expect( model_traced_backends == None or is_list(model_traced_backends), "Expected model traced backends to be a list of string if specified", ) if model_traced_backends != None: for backend in model_traced_backends: expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend))) expect( backend in USED_PT_BACKENDS, "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)), ) return (model_name, model_versions, model_assets, model_traced_backends) # This file keeps a list of PyTorch operators used by any targets in # @fbsource//xplat/... # The purpose of the list is to avoid generating large number of unused # operator registration code / BUCK rules at build time. # See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv PT_OPS_PRIM = [ "aten::str", "aten::list", "aten::__range_length", "aten::__derive_index", "prim::TupleUnpack", "prim::unchecked_cast", "aten::IntImplicit", "aten::FloatImplicit", "aten::ScalarImplicit", "aten::Bool.Tensor", "aten::Bool.int", "aten::Bool.float", "aten::Int.Tensor", "aten::Int.Scalar", "aten::Int.int", "aten::Int.bool", "aten::Int.str", "aten::Float.Tensor", "aten::Float.Scalar", "aten::Float.int", "aten::Float.bool", "aten::Float.str", "aten::format", "prim::NumToTensor.Scalar", "prim::RaiseException", "aten::Size", "aten::size", "prim::EnumName", "prim::EnumValue.int", "prim::EnumValue.float", "prim::EnumValue.str", "prim::TupleIndex", "aten::ne.int_list", "prim::unchecked_unwrap_optional", "prim::device", "prim::dtype", "aten::__not__", "aten::__is__", "aten::__isnot__", "aten::element_size", "aten::numel", "aten::dim", "aten::get_device", "aten::storage_offset", "aten::is_contiguous", "aten::select.t", "aten::__getitem__.t", "aten::append.t", "aten::reverse.t", "aten::extend.t", "aten::copy.t", "aten::_set_item.t", "aten::clear.t", "aten::Delete.t", "aten::insert.t", "aten::pop.t", "aten::add.t", "aten::add_.t", "aten::slice.t", "aten::list.t", "aten::mul.left_t", "aten::mul.right_", "aten::mul_.t", "aten::len.t", "aten::eq.int_list", "prim::Uninitialized", "prim::Print", "aten::eq.enum", "aten::ne.enum", "aten::dequantize.tensor", "aten::dequantize.any", "aten::add.str", "aten::eq.int", "aten::eq.float", "aten::eq.int_float", "aten::eq.float_int", "aten::eq", "aten::eq.str", "aten::ne.int", "aten::ne.float", "aten::ne.int_float", "aten::ne.float_int", "aten::ne", "aten::ne.str", "aten::lt.int", "aten::lt.float", "aten::lt.int_float", "aten::lt.float_int", "aten::lt", "aten::lt.str", "aten::gt.int", "aten::gt.float", "aten::gt.int_float", "aten::gt.float_int", "aten::gt", "aten::gt.str", "aten::le.int", "aten::le.float", "aten::le.int_float", "aten::le.float_int", "aten::le", "aten::le.str", "aten::ge.int", "aten::ge.float", "aten::ge.int_float", "aten::ge.float_int", "aten::ge", "aten::ge.str", "aten::add.int", "aten::add.float", "aten::add.int_float", "aten::add.float_int", "aten::add", "aten::sub.int", "aten::sub.float", "aten::sub.int_float", "aten::sub.float_int", "aten::sub", "aten::mul.int", "aten::mul.float", "aten::mul.int_float", "aten::mul.float_int", "aten::mul", "aten::__and__.bool", "aten::__or__.bool", "aten::__xor__.bool", "aten::floor.int", "aten::floor.float", "aten::floor.Scalar", "aten::ceil.int", "aten::ceil.float", "aten::ceil.Scalar", "aten::neg.int", "aten::neg.float", "aten::neg.Scalar", "aten::exp.int", "aten::exp.float", "aten::exp.Scalar", "aten::remainder.int", "aten::remainder.float", "aten::remainder.int_float", "aten::remainder.float_int", "aten::remainder", "aten::div.int", "aten::div.float", "aten::div", "aten::floordiv.int", "aten::floordiv.float", "aten::floordiv.int_float", "aten::floordiv.float_int", "aten::floordiv", "aten::pow.int", "aten::pow.float", "aten::pow.int_float", "aten::pow.float_int", "aten::pow.Scalar_Scalar", "aten::pow.int_to_int", "prim::min.int", "prim::min.float", "prim::min.int_float", "prim::min.float_int", "prim::min", "prim::max.int", "prim::max.float", "prim::max.int_float", "prim::max.float_int", "prim::max", "prim::type", "aten::len.Tensor", "aten::ord", "aten::lower", "aten::__contains__.str_list", "aten::len.str", "aten::__getitem__.str", "aten::copy_.Tensor", "aten::copy_.int", "aten::copy_.float", "aten::backward", "aten::index.Tensor_hacked_twin", "aten::_unsafe_index.Tensor_hacked_twin", "aten::_index_put_impl_.hacked_twin", "aten::index_put_.hacked_twin", "aten::index_put.hacked_twin", "aten::_unsafe_index_put.hacked_twin", "aten::to.prim_Device", "aten::to.prim_dtype", "prim::is_cuda", "prim::data", "prim::min.int_list", "prim::max.int_list", "prim::min.self_int", "prim::max.self_int", "prim::min.float_list", "prim::max.float_list", "prim::min.self_float", "prim::max.self_float", "prim::min.bool_list", "prim::max.bool_list", "prim::min.self_bool", "prim::max.self_bool", "aten::len.Dict_str", "aten::keys.str", "aten::values.str", "aten::__getitem__.Dict_str", "aten::get.str", "aten::get.default_str", "aten::setdefault.str", "aten::Delete.Dict_str", "aten::pop.Dict_str", "aten::pop.Dict_default_str", "aten::popitem.str", "aten::clear.str", "aten::update.str", "aten::items.str", "aten::copy.Dict_str", "aten::__contains__.str", "aten::_set_item.str", "aten::dict.str", "aten::len.Dict_int", "aten::keys.int", "aten::values.int", "aten::__getitem__.Dict_int", "aten::get.int", "aten::get.default_int", "aten::setdefault.int", "aten::Delete.Dict_int", "aten::pop.Dict_int", "aten::pop.Dict_default_int", "aten::popitem.int", "aten::clear.int", "aten::update.int", "aten::items.int", "aten::copy.Dict_int", "aten::__contains__.int", "aten::_set_item.int", "aten::dict.int", "aten::len.Dict_bool", "aten::keys.bool", "aten::values.bool", "aten::__getitem__.Dict_bool", "aten::get.bool", "aten::get.default_bool", "aten::setdefault.bool", "aten::Delete.Dict_bool", "aten::pop.Dict_bool", "aten::pop.Dict_default_bool", "aten::popitem.bool", "aten::clear.bool", "aten::update.bool", "aten::items.bool", "aten::copy.Dict_bool", "aten::__contains__.bool", "aten::_set_item.bool", "aten::dict.bool", "aten::len.Dict_float", "aten::keys.float", "aten::values.float", "aten::__getitem__.Dict_float", "aten::get.float", "aten::get.default_float", "aten::setdefault.float", "aten::Delete.Dict_float", "aten::pop.Dict_float", "aten::pop.Dict_default_float", "aten::popitem.float", "aten::clear.float", "aten::update.float", "aten::items.float", "aten::copy.Dict_float", "aten::__contains__.float", "aten::_set_item.float", "aten::dict.float", "aten::len.Dict_Tensor", "aten::keys.Tensor", "aten::values.Tensor", "aten::__getitem__.Dict_Tensor", "aten::get.Tensor", "aten::get.default_Tensor", "aten::setdefault.Tensor", "aten::Delete.Dict_Tensor", "aten::pop.Dict_Tensor", "aten::pop.Dict_default_Tensor", "aten::popitem.Tensor", "aten::clear.Tensor", "aten::update.Tensor", "aten::items.Tensor", "aten::copy.Dict_Tensor", "aten::__contains__.Tensor", "aten::_set_item.Tensor", "aten::dict.Tensor", "aten::__round_to_zero_floordiv.int", "aten::mathremainder.int", "aten::mathremainder.float", "aten::mathremainder.int_float", "aten::mathremainder.float_int", "aten::mathremainder", "aten::__and__.int", "aten::__or__.int", "aten::__xor__.int", "aten::__lshift__.int", "aten::__rshift__.int", "aten::round.int", "aten::round.float", "aten::round.Scalar", "aten::log.int", "aten::log.float", "aten::log.Scalar", "aten::log.int_int", "aten::log.float_float", "aten::log.int_float", "aten::log.float_int", "aten::log.Scalar_Scalar", "aten::log1p.int", "aten::log1p.float", "aten::log1p.Scalar", "aten::log10.int", "aten::log10.float", "aten::log10.Scalar", "aten::sqrt.int", "aten::sqrt.float", "aten::sqrt.Scalar", "aten::acos.int", "aten::acos.float", "aten::acos.Scalar", "aten::asin.int", "aten::asin.float", "aten::asin.Scalar", "aten::atan.int", "aten::atan.float", "aten::atan.Scalar", "aten::atan2.int", "aten::atan2.float", "aten::atan2.int_float", "aten::atan2.float_int", "aten::atan2.Scalar_Scalar", "aten::cos.int", "aten::cos.float", "aten::cos.Scalar", "aten::sin.int", "aten::sin.float", "aten::sin.Scalar", "aten::tan.int", "aten::tan.float", "aten::tan.Scalar", "aten::asinh.int", "aten::asinh.float", "aten::asinh.Scalar", "aten::atanh.int", "aten::atanh.float", "aten::atanh.Scalar", "aten::acosh.int", "aten::acosh.float", "aten::acosh.Scalar", "aten::sinh.int", "aten::sinh.float", "aten::sinh.Scalar", "aten::cosh.int", "aten::cosh.float", "aten::cosh.Scalar", "aten::tanh.int", "aten::tanh.float", "aten::tanh.Scalar", "aten::degrees.int", "aten::degrees.float", "aten::degrees.Scalar", "aten::radians.int", "aten::radians.float", "aten::radians.Scalar", "aten::fmod.int", "aten::fmod.float", "aten::fmod.int_float", "aten::fmod.float_int", "aten::fmod", "aten::factorial.int", "aten::isnan.float", "aten::isfinite.float", "aten::isinf.float", "aten::gamma.int", "aten::gamma.float", "aten::gamma.Scalar", "aten::erf.int", "aten::erf.float", "aten::erf.Scalar", "aten::erfc.int", "aten::erfc.float", "aten::erfc.Scalar", "aten::expm1.int", "aten::expm1.float", "aten::expm1.Scalar", "aten::fabs.int", "aten::fabs.float", "aten::fabs.Scalar", "aten::lgamma.int", "aten::lgamma.float", "aten::lgamma.Scalar", "prim::abs.int", "prim::abs.float", "prim::abs.Scalar", "aten::gcd.int", "aten::copysign.int", "aten::copysign.float", "aten::copysign.int_float", "aten::copysign.float_int", "aten::copysign", "aten::split", "aten::tensor.float", "aten::as_tensor.float", "aten::tensor.int", "aten::as_tensor.int", "aten::tensor.bool", "aten::as_tensor.bool", "aten::_infer_size", "aten::_no_grad_embedding_renorm_", "aten::tensor", "aten::as_tensor", "aten::as_tensor.list", "aten::_pack_sequence", "aten::_get_tracing_state", "aten::is_scripting", "aten::_no_grad_uniform_", "aten::_no_grad_normal_", "aten::_no_grad_fill_", "aten::_no_grad_zero_", ] PT_BASE_OPS = [ "aten::_coalesced_", "aten::_copy_from", "aten::_empty_affine_quantized", "aten::_empty_per_channel_affine_quantized", "aten::_indices", "aten::_nnz", "aten::_values", "aten::add", "aten::add_", "aten::arange", "aten::as_strided", "aten::as_strided_", "aten::cat", "aten::clone", "aten::coalesce", "aten::contiguous", "aten::copy_", "aten::copy_sparse_to_sparse_", "aten::dense_dim", "aten::dequantize", "aten::div", "aten::div_", "aten::empty", "aten::empty_like", "aten::empty_strided", "aten::eq", "aten::equal", "aten::expand", "aten::fill_", "aten::is_coalesced", "aten::is_complex", "aten::is_floating_point", "aten::is_leaf", "aten::is_nonzero", "aten::item", "aten::max", "aten::min", "aten::mul", "aten::mul_", "aten::narrow", "aten::ne", "aten::permute", "aten::q_per_channel_axis", "aten::q_per_channel_scales", "aten::q_per_channel_zero_points", "aten::q_scale", "aten::q_zero_point", "aten::qscheme", "aten::quantize_per_tensor", "aten::reshape", "aten::_reshape_alias", "aten::resize_", "aten::resize_as_", "aten::scalar_tensor", "aten::select", "aten::set_", "aten::size", "aten::slice", "aten::sparse_dim", "aten::sparse_resize_and_clear_", "aten::squeeze", "aten::squeeze_", "aten::stride", "aten::sub", "aten::sub_", "aten::sum", "aten::t", "aten::to", "aten::_to_copy", "aten::unsqueeze", "aten::view", "aten::zero_", "aten::zeros", "aten::zeros_like", ]