Convert ONNX models to PyTorch.

Overview

onnx2torch

onnx2torch is an ONNX to PyTorch converter. Our converter:

  • Is easy to use – Convert the ONNX model with the function call convert;
  • Is easy to extend – Write your own custom layer in PyTorch and register it with @add_converter;
  • Can convert back to ONNX – You can convert the model back to ONNX using the torch.onnx.export function.

If you find an issue, please let us know! And feel free to create merge requests.

Please note that this converter covers only a limited number of PyTorch / ONNX models and operations.
Let us know which models you use or want to convert from ONNX to PyTorch.

Installation

From PyPI

pip install onnx2torch

Usage

Below you can find some examples of use.

Convert

import onnx
import torch
from onnx2torch.converter import convert

# Path to the ONNX model
onnx_model_path = '/some/path/mobile_net_v2.onnx'
# You can pass the path to the ONNX model to convert it, or...
torch_model_1 = convert(onnx_model_path)

# ...you can load the ONNX model first and pass it to the converter
onnx_model = onnx.load(onnx_model_path)
torch_model_2 = convert(onnx_model)

Execute

We can execute the returned PyTorch model in the same way as the original torch model.

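import numpy as np
import onnxruntime as ort

# Create example data; the shape must match the input expected by your model.
# The tensor stays on the CPU so that it can be shared with onnxruntime.
x = torch.ones((1, 2, 224, 224))

out_torch = torch_model_1(x)

ort_sess = ort.InferenceSession(onnx_model_path)
# 'input' must be the name of the model's input tensor
outputs_ort = ort_sess.run(None, {'input': x.numpy()})

# Check the ONNX output against PyTorch (run() returns a list of arrays)
out_torch_np = out_torch.detach().numpy()
print(np.max(np.abs(outputs_ort[0] - out_torch_np)))
print(np.allclose(outputs_ort[0], out_torch_np, atol=1.e-7))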

Models

We have tested the following models:

  • ResNet50
  • SSDLite with MobileNetV2 backbone

How to add new operations to converter

Here we show how to add an operation in two cases:

  1. The operation is supported by both PyTorch and ONNX and has the same behaviour.
    An example of such an operation is Relu:
@add_converter(operation_type='Relu', version=6)
@add_converter(operation_type='Relu', version=13)
@add_converter(operation_type='Relu', version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
    return OperationConverterResult(
        torch_module=nn.ReLU(),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )

Here we have registered an operation named Relu for opset versions 6, 13 and 14.
Note that the torch_module argument of OperationConverterResult must be a torch.nn.Module, not just a callable object!
If an operation's behaviour differs from one opset version to another, you should implement a separate converter for each version.
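
The stacked decorators can be pictured as entries in a lookup table keyed by operation type and opset version. The following is a minimal sketch of that idea in plain Python. It is not the actual onnx2torch implementation, and the names _REGISTRY, register_converter, find_converter and convert_relu are hypothetical.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (operation_type, opset_version) -> converter function.
_REGISTRY: Dict[Tuple[str, int], Callable] = {}


def register_converter(operation_type: str, version: int) -> Callable:
    """Return a decorator that stores the function under (operation_type, version)."""

    def decorator(func: Callable) -> Callable:
        key = (operation_type, version)
        if key in _REGISTRY:
            raise ValueError(f'Converter for {key} is already registered')
        _REGISTRY[key] = func
        # Return the function unchanged so that decorators can be stacked.
        return func

    return decorator


def find_converter(operation_type: str, version: int) -> Callable:
    """Look up the converter registered for exactly this opset version."""
    try:
        return _REGISTRY[(operation_type, version)]
    except KeyError as exc:
        raise NotImplementedError(
            f'No converter for {operation_type} (opset {version})'
        ) from exc


# One function registered under three opset versions, as in the Relu example.
@register_converter(operation_type='Relu', version=6)
@register_converter(operation_type='Relu', version=13)
@register_converter(operation_type='Relu', version=14)
def convert_relu(node: str) -> str:
    return f'ReLU module for {node}'
```

With such a table, a version that was never registered fails loudly at lookup time, which is why each opset version whose behaviour you support has to be listed explicitly.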

  2. The operation is supported by both PyTorch and ONNX but has different behaviour. An example of such an operation is Expand:
class OnnxExpand(nn.Module):

    @staticmethod
    def _do_forward(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
        return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device)

    def forward(self, *args) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            with skip_torch_tracing():
                output = self._do_forward(*args)
                return _ExpandExportToOnnx.set_output_and_apply(output, *args)

        return self._do_forward(*args)


class _ExpandExportToOnnx(CustomExportToOnnx):

    @staticmethod
    def symbolic(graph: torch_C.Graph, *args, **kwargs) -> torch_C.Value:
        return graph.op('Expand', *args, **kwargs, outputs=1)


@add_converter(operation_type='Expand', version=8)
@add_converter(operation_type='Expand', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
    return OperationConverterResult(
        torch_module=OnnxExpand(),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )

Here we have used a trick to convert the model from torch back to ONNX by defining the custom _ExpandExportToOnnx.
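
Some background on why Expand needs its own module: ONNX Expand applies right-aligned, two-way broadcasting between the input shape and the requested shape, so the output can be larger than the requested shape. torch.Tensor.expand, by contrast, only enlarges dimensions of size 1. Multiplying by a tensor of ones, as _do_forward does, follows the broadcasting rule. The sketch below computes only the resulting shape in plain Python; the helper name expand_output_shape is hypothetical.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def expand_output_shape(input_shape: Sequence[int], target_shape: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of ONNX Expand: right-aligned, two-way broadcasting."""
    result = []
    # Walk both shapes from the last dimension; missing dimensions count as 1.
    pairs = zip_longest(reversed(input_shape), reversed(target_shape), fillvalue=1)
    for in_dim, target_dim in pairs:
        if in_dim != target_dim and 1 not in (in_dim, target_dim):
            raise ValueError(f'Cannot broadcast {tuple(input_shape)} with {tuple(target_shape)}')
        result.append(max(in_dim, target_dim))
    return tuple(reversed(result))


print(expand_output_shape((3, 1), (2, 1, 6)))  # (2, 3, 6)
print(expand_output_shape((3, 4), (1,)))       # (3, 4): the input is already larger than the target
```

The second call is the case that torch.Tensor.expand rejects, because the requested size 1 does not match the existing non-singleton dimension.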

Comments
  • IndexError because pytorch does not support int32 for indexing like onnx

    Been trying a bunch of things to solve this now but the error message isn't very helpful. It just sounds like it's just the input dtype that's handled incorrectly.

      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
        raise e.with_traceback(None)
    IndexError: tensors used as indices must be long, byte or bool tensors
    

    Here is the onnx file I'm trying to convert https://drive.google.com/file/d/1FX_D6dcYEoVssr-y29F5RLbhmhf4RmyZ/view?usp=sharing

    It is supposed to take a tensor of type long. Here's an example in json: https://drive.google.com/file/d/1yTSxwOY10g0cULEVt3KQZWSzziDpD8bC/view?usp=sharing

    This is how I try to run it:

        model = torch.load(torch_path).eval().requires_grad_(False)
    
        tokens = torch.tensor(
            json.loads(Path(f"tests/{model_name}/tokens.json").read_text())
        ).long()
        text_encodings = torch.from_numpy(np.array(model(tokens)))
    

    The onnx file works with onnxruntime and gives the correct result.

        tokens = np.array(
            json.loads(Path(f"tests/{model_name}/tokens.json").read_text())
        ).astype(np.int64)
    
        providers = ["CPUExecutionProvider"]
        m = rt.InferenceSession(onnx_path, providers=providers)
        text_encoding = m.run(output_names, dict(inputs=tokens))
    

    Do you have any tips for what could be wrong?

    opened by samedii 12
  • Key error for Reshape when trying to convert ONNX model

    Do you think this error is occurring because the weight name is missing from the ONNX file, or because the operation isn't found?

    I'm getting the right results if I load and run the model with ONNX runtime at least.

      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/onnx2torch/converter.py", line 109, in convert
        torch_module, onnx_mapping = converter(onnx_node, onnx_graph)
      File "/home/richard/miniconda3/envs/3.9.13/lib/python3.9/site-packages/onnx2torch/node_converters/gemm.py", line 65, in _
        weights = graph.initializers[weights_value_name]
    KeyError: 'Reshape__2177:0'
    

    Here is a link to the ONNX file if you want to try it yourself https://drive.google.com/file/d/1P_Bl7n2hbUoOhfMh_9UJfgK_N8kP_xjB/view?usp=sharing

    I'm trying to make this model https://paperswithcode.com/paper/lit-zero-shot-transfer-with-locked-image-text easier to use for people

    bug 
    opened by samedii 5
  • release on conda-forge or similar?

    Hi there,

    Are there plans to release the package for conda installation through conda-forge or another channel? I would love to be able to conda install this package so it will track in my conda environments.

    (My personal machine is on Mac OS, and I also work on a Linux server, so if I could put in requests for build OSs it would be those.)

    thank you for your consideration!

    feat 
    opened by monicathieu 5
  • feat: Add LRN opset 13 support

    Hi y'all,

    After installing the package (through conda-forge! 🙏🏼 ), I went to convert my ONNX model, which is a modified AlexNet model (same layer architecture, just retrained with different image classes and new weights in the last layer). At this point, I realized that local response normalization was not yet implemented in onnx2torch! I looked at the relevant pytorch docs and ONNX docs, and it looks like both of them implement the formula directly from the original AlexNet paper, just with slightly different argument names/order. Accordingly, I thought I might be able to handle adding a new converter for LRN.

    I do want to note that I haven't added any tests for the converter, because that was beyond my Python/pytorch ability (this is all part of me learning!), but when I use my local version of the package to import my modified AlexNet model, it seeeeems not to have been mangled.

    If you'd like me to work on adding the tests, I can, but it would take me some time to figure out the syntax. I would be thrilled if it ended up being very fast for you guys to add the tests yourselves.

    feat 
    opened by monicathieu 4
  • A lot of Ops with their implementations.

    Hi, developer team of onnx2torch. I am currently developing a neural network quantization framework: https://github.com/openppl-public/ppq/tree/master/ppq. The really interesting part is that we both need to run an ONNX model with PyTorch : ) I am glad to share our operator implementations with you: https://github.com/openppl-public/ppq/blob/master/ppq/executor/op/torch/default.py

    We support the following ONNX operators so far (still a work in progress):

    1. 'Abs': Abs_forward,
    2. 'AdaptiveAvgPool2d': AdaptiveAvgPool2d_forward,
    3. 'And':And_forward,
    4. 'Add': Add_forward,
    5. 'ArgMax': ArgMax_forward,
    6. 'AveragePool': AveragePool_forward,
    7. 'BatchNormalization': BatchNormalization_forward,
    8. 'Cast': Cast_forward,
    9. 'Clip': Clip_forward,
    10. 'Concat': Concat_forward,
    11. 'Constant': Constant_forward,
    12. 'ConstantOfShape': ConstantOfShape_forward,
    13. 'Conv': Conv_forward,
    14. 'ConvTranspose': ConvTranspose_forward,
    15. 'Cos': Cos_forward,
    16. 'Div': Eltwise_forward,
    17. 'Equal': Equal_forward,
    18. 'Exp': UnaryEltwise_forward,
    19. 'Expand': Expand_forward,
    20. 'Flatten': Flatten_forward,
    21. 'Gather': Gather_forward,
    22. 'GatherElements': Gather_forward,
    23. 'GatherND': GatherND_forward,
    24. 'Gelu': Gelu_forward,
    25. 'Gemm': Gemm_forward,
    26. 'grid_sampler': Grid_sampler_forward,
    27. 'GlobalAveragePool': AveragePool_forward,
    28. 'GlobalMaxPool': MaxPool2d_forward,
    29. 'Greater': Greater_forward,
    30. 'LayerNorm': LayerNorm_forward,
    31. 'LeakyRelu': LeakyRelu_forward,
    32. 'Less': Less_forward,
    33. 'LogSoftmax': LogSoftmax_forward,
    34. 'MatMul': MatMul_forward,
    35. 'Max': Eltwise_forward,
    36. 'MaxPool': MaxPool2d_forward,
    37. 'Min': Eltwise_forward,
    38. 'Mul': Mul_forward,
    39. 'MultiHeadAttention': MultiHeadAttention_forward,
    40. 'NonMaxSuppression': _NMS_forward,
    41. 'NonZero': NonZero_forward,
    42. 'Not': Not_forward,
    43. 'Pad': Pad_forward,
    44. 'PRelu': PRelu_forward,
    45. 'Range': Range_forward,
    46. 'ReduceL2': ReduceL2_forward,
    47. 'ReduceMax': ReduceMax_forward,
    48. 'ReduceMean': ReduceMean_forward,
    49. 'ReduceSum': ReduceSum_forward,
    50. 'Relu': UnaryEltwise_forward,
    51. 'Reshape': Reshape_forward,
    52. 'Resize': Resize_forward,
    53. 'ScatterElements': ScatterElements_forward,
    54. 'ScatterND': ScatterND_forward,
    55. 'Shape': Shape_forward,
    56. 'Sigmoid': UnaryEltwise_forward,
    57. 'Sin': Sin_forward,
    58. 'Slice': Slice_forward,
    59. 'Softmax': Softmax_forward,
    60. 'Softplus': Softplus_forward,
    61. 'Split': Split_forward,
    62. 'Squeeze': Squeeze_forward,
    63. 'Sub': Eltwise_forward,
    64. 'Tile': Tile_forward,
    65. 'TopK': TopK_forward,
    66. 'Transpose': Transpose_forward,
    67. 'Unsqueeze': Unsqueeze_forward,
    68. 'Where': Where_forward,
    69. 'Sqrt': Sqrt_forward,
    70. 'Log': Log_forward,
    71. 'Floor': Floor_forward,
    72. 'RoiAlign': RoiAlign_forward,
    73. 'MMCVRoiAlign': MMCVRoiAlign_forward,
    74. 'SpaceToDepth': SpaceToDepth_forward,
    75. 'DepthToSpace': DepthToSpace_forward,
    76. 'Scale': Scale_forward, # caffe op
    77. 'Tanh': Tanh_forward,
    78. 'Pow': Pow_forward,
    79. 'Crop': Crop_forward, # caffe op
    80. 'ChannelShuffle': ChannelShuffle_forward, # caffe op
    81. 'InstanceNormalization': InstanceNormalization_forward,
    82. 'Parameter': Parameter_forward, # caffe op
    83. 'Interp': Interp_forward, # caffe op
    84. 'CaffeArgMax': CaffeArgMax_forward, # caffe op
    85. 'HardSigmoid': HardSigmoid_forward,
    86. 'HardSwish': HardSwish_forward,
    87. 'Neg': Neg_forward,
    88. 'GRU': GRU_forward,
    89. 'PPQDeviceSwitch': PPQDeviceSwitch_forward,
    90. 'Identity': Identity_forward,
    91. 'OneHot': Onehot_forward,
    92. 'Reciprocal': Reciprocal_forward,
    93. 'LSTM': LSTM_forward,
    Stale 
    opened by ZhangZhiPku 4
  • Does not support the convert of xception?

    I successfully converted the Xception ONNX model file to a .pt file, but the predicted data is inconsistent with the original ONNX file.

    Is there an unsupported operation? The conversion did not throw any error message.

    bug 
    opened by koon-kai 3
  • GlobalAveragePool: make it compatible with PyTorch FX QAT

    range() is not compatible with torch.fx symbolic tracing yet [1]. I created a slightly different module that avoids range() in forward() when possible (i.e., inferred shapes exist).

    [1] https://github.com/pytorch/pytorch/issues/44851

    Here is an example of the current issue:

    import onnx
    import onnx.checker
    import onnx.helper
    import onnx2torch
    import torch.quantization.quantize_fx
    
    # https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.0-12.onnx
    model = onnx.load_model('squeezenet1.0-12.onnx')
    net = onnx2torch.convert(model)
    
    net.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
    qconfig_dict = {'': net.qconfig}
    torch.quantization.quantize_fx.prepare_qat_fx(net, qconfig_dict)
    

    torch.fx fails even after adding "torch.fx.wrap('len')" to onnx2torch/node_converters/global_average_pool.py:

    Traceback (most recent call last):
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/t.py", line 20, in <module>
        main()
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/t.py", line 17, in main
        quantize_fx.prepare_qat_fx(net, qconfig_dict)
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 564, in prepare_qat_fx
        return _prepare_fx(
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 237, in _prepare_fx
        graph_module = GraphModule(model, tracer.trace(model))
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 566, in trace
        self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
      File "<eval_with_key>.0", line 71, in forward
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 556, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
      File "/usr/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 158, in call_module
        return super().call_module(m, forward, args, kwargs)
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 372, in call_module
        return forward(*args, **kwargs)
      File "/usr/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 552, in forward
        return _orig_module_call(mod, *args, **kwargs)
      File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yen/var/local/Computer/machine-learning/onnx2torch/onnx2torch/node_converters/global_average_pool.py", line 19, in forward
        x_dims = list(range(2, len(input_tensor.shape)))
    TypeError: 'Proxy' object cannot be interpreted as an integer
    
    fix 
    opened by yan12125 3
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks to see if all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.
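
The kind of check described above can be sketched as follows. This is an illustration under stated assumptions, not the patch from the pull request: the names is_within_directory and safe_extractall are hypothetical, and the check covers member names only, not symlink or hard-link members.

```python
import os
import tarfile


def is_within_directory(directory: str, target: str) -> bool:
    """Return True if target resolves to a path inside directory."""
    abs_directory = os.path.realpath(directory)
    abs_target = os.path.realpath(target)
    return os.path.commonpath([abs_directory, abs_target]) == abs_directory


def safe_extractall(tar: tarfile.TarFile, path: str = '.') -> None:
    """Extract all members, refusing archives that would write outside path."""
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise ValueError(f'Unsafe path in tar archive: {member.name}')
    # Only reached when every member name stays inside the target directory.
    tar.extractall(path)
```

A member named ../evil.txt resolves outside the target directory, so the whole archive is rejected before anything is written.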

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.

    Stale 
    opened by TrellixVulnTeam 2
  • Gemm Bug

    https://github.com/ENOT-AutoDL/onnx2torch/blob/2791442c88af2bd243d0092963838ce940f8f1b0/onnx2torch/node_converters/gemm.py#L67 is wrong. It should be: if bias is None or bias.dim() == 1:

    Stale 
    opened by JiangYongYu1 2
  • feat: add LRN operator support (attempt 2)

    This PR is for the same additions as #96 , but I had to make a new one. I renamed the "main" branch on my fork so that pre-commit no-commit-to-branch would pass, but in doing so, it deleted the branch that was associated with the old PR and auto-closed the PR. So, here we are. Phew! I think that's it though?

    opened by monicathieu 2
  • [Code Quality] Remove trailing whitespace and add isort for module import

    Use pre-commit to auto remove all trailing whitespace.

    Usage

    Use pip (or brew on macOS) to install pre-commit first

    pip install pre-commit
    

    Install the git hook scripts to this project

    pre-commit install
    

    Run against all the files

    $ pre-commit run --all-files
    

    Run for some files

    $ pre-commit run --files a.py b.py xxx.py
    enhancement 
    opened by triple-Mu 2
  • KeyError: 'Tensor "onnx::Clip_2052" is not found in constant values'

    I am trying to convert a LayoutMl model to a PyTorch model and get this error:

        KeyError                                  Traceback (most recent call last)
        /usr/local/lib/python3.8/dist-packages/onnx2torch/node_converters/clip.py in _(node, graph)
             56 try:
        ---> 57 min_val = float(get_const_value(min_name, graph)) if min_name is not None else None
             58 max_val = float(get_const_value(max_name, graph)) if max_name is not None else None

        3 frames
        KeyError: 'Tensor "onnx::Clip_2052" is not found in constant values'

        The above exception was the direct cause of the following exception:

        NotImplementedError                       Traceback (most recent call last)
        /usr/local/lib/python3.8/dist-packages/onnx2torch/node_converters/clip.py in _(node, graph)
             58 max_val = float(get_const_value(max_name, graph)) if max_name is not None else None
             59 except KeyError as exc:
        ---> 60 raise NotImplementedError('Dynamic value of min/max is not implemented') from exc
             61
             62 torch_module = _create_torch_module(min_val=min_val, max_val=max_val)

        NotImplementedError: Dynamic value of min/max is not implemented

    opened by NeoKun004 0
  • save converted yolov5

    I exported the module to a folder using "to_folder" and hit the error "NameError: name 'onnx2torch_converter_lambda_' is not defined" when running inference after importing the module. Any tips on what's wrong?

    opened by clh-b 2
  • Is There any possible way to have OnnxGemm converted into torch.nn.Linear?

    Hi, it is kind of you to share your work. I appreciate it very much.

    I have noticed that a Gemm layer of the ONNX graph is converted into an OnnxGemm object.

    For example: (image)

    Then we check the PyTorch model (image) and its forward (image).

    I wonder whether it is possible to have these Gemm layers converted into native PyTorch modules? That would be much more scalable.
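
As background for this question: the ONNX Gemm operator computes Y = alpha * A' * B' + beta * C, where A' and B' are the optionally transposed inputs, while torch.nn.Linear computes y = x W^T + b. The two coincide when alpha = beta = 1, transA = 0, transB = 1 and C is a 1-D bias. The sketch below checks that equivalence in plain Python with lists of floats standing in for tensors. The helper names are hypothetical, and the sketch says nothing about how onnx2torch decides which module to create.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    b_t = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_t] for row in a]


def gemm(a: Matrix, b: Matrix, c: List[float], alpha: float = 1.0, beta: float = 1.0,
         trans_a: bool = False, trans_b: bool = False) -> Matrix:
    """ONNX Gemm with a 1-D C that is broadcast over the rows of the product."""
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    product = matmul(a, b)
    return [[alpha * value + beta * bias for value, bias in zip(row, c)] for row in product]


def linear(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """y = x @ weight^T + bias, the formula used by torch.nn.Linear."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


x = [[1.0, 2.0], [3.0, 4.0]]                   # 2 samples with 2 features each
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 output features
bias = [0.5, -0.5, 1.0]

print(gemm(x, weight, bias, trans_b=True) == linear(x, weight, bias))  # True
```

For other attribute combinations (alpha or beta different from 1, transA = 1, or a 2-D C) the Gemm result no longer matches a plain Linear layer without adjusting the weights or adding extra operations.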

    Thank you!!

    opened by feanor21115 0
  • GatherND missing

    While mentioned in release v1.2.0, GatherND is not implemented. See https://github.com/ENOT-AutoDL/onnx2torch/pull/25. Are there any plans to still add the operator?

    opened by jonas-doevenspeck 4
Releases(v1.5.4)
  • v1.5.4(Nov 14, 2022)

  • v1.5.3(Sep 16, 2022)

  • v1.5.2(Sep 9, 2022)

  • v1.5.1(Sep 1, 2022)

  • v1.5.0(Aug 30, 2022)

    You've been waiting, you've asked... And here it is: a new release (1.5.0)!

    In this release:

    • Fixed shape inference of big models (now you can load your lovely big language model, like gptj6B);
    • Fixed the BatchNorm converter according to the ONNX spec;
    • Added asymmetric padding for both Conv and MaxPool.

    Thank you for your discussions and issues! We are always ready to help you!

  • v1.4.1(Jul 12, 2022)

  • v1.4.0(Jul 6, 2022)

    :rocket: New day - new release! :rocket: New features in release 1.4.0:

    • Added operations:

      • Einsum,
      • Reciprocal,
      • Neg,
      • Prelu,
      • Mean,
      • Min,
      • Max,
      • CumSum,
      • Sum.
    • Fixes:

      • GlobalAveragePool compatible with PyTorch FX QAT,
      • Compatibility with torch 1.12,
      • Allow to ignore bs+ch dimensions for size input in Resize.
      • Check roi only for tf_crop_and_resize mode.
    • Code style and workflows:

    We thank all contributors for their work! Together we will make onnx2torch great! :star: 10^6

  • v1.3.0(Apr 20, 2022)

    New features in release 1.3.0:

    • Add Pad operation
    • Add Hardswish, Celu, Elu, Selu, Softplus, Softsign functions
    • Add GatherElements operation
    • Add new model tests

    Fixes:

    • Fix ConvTranspose operation
  • v1.2.5(Apr 7, 2022)

  • v1.2.0(Mar 23, 2022)

    New features in release 1.2.0:

    • Support VIT, SWIN, FasterRcnn models;
    • Add RoiAlign operation;
    • Add Floor, Ceil, Round and trigonometric functions;
    • New model tests for classification, detection and segmentation;

    Fixes:

    • Fix ScatterND and its speed;
    • Fix names of placeholders;
    • Fix optional arguments if last arguments are not passed;
    • Fix checking of empty inputs;
    • Fix avg pool operation;
  • v1.1.0(Jan 18, 2022)

    New features in release 1.1.0:

    • Support of optional input arguments;
    • Fixed dynamic axes for Squeeze and Unsqueeze.

    Also added the new operations listed below:

    • Resize
    • Reduce operations (ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceMax, ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare)
    • Logical operations (Or, Xor, And, Not)
    • HardSigmoid, LeakyRelu
    • Pow, Sqrt

  • v1.0.0(Dec 14, 2021)

Owner
ENOT