Source code for onnx_diagnostic.tasks.text_classification
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import update_config, check_hasattr

__TASK__ = "text-classification"

def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
    """Reduces the model size by capping the number of hidden layers and attention heads."""
    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
    kwargs = dict(
        num_hidden_layers=min(config.num_hidden_layers, 2),
        num_attention_heads=min(config.num_attention_heads, 4),
    )
    update_config(config, kwargs)
    return kwargs
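
# Usage sketch, assuming a transformers-style configuration such as BERT's
# (``AutoConfig`` is an assumption of the example, not a dependency of this
# module; any object exposing the two attributes works):
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("bert-base-uncased")
#     reduced = reduce_model_config(config, task=__TASK__)
#     # the config is updated in place and the applied values are returned:
#     # reduced == {"num_hidden_layers": 2, "num_attention_heads": 4}
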
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    batch_size: int,
    sequence_length: int,
    dummy_max_token_id: int,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``text-classification``.

    Example::

        input_ids:T7s1x13[101,72654:A16789.23076923077],
        token_type_ids:T7s1x13[0,0:A0.0],
        attention_mask:T7s1x13[1,1:A1.0]
    """
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
    # All three tensors share the same two dynamic dimensions.
    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: seq_length},
    }
    # Random token ids, a single segment, and a full attention mask.
    inputs = dict(
        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
        token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
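
# Usage sketch: the returned pair plugs directly into ``torch.export.export``
# (``model`` is assumed to be any text-classification module accepting these
# three keyword tensors):
#
#     data = get_inputs(
#         model, config, batch_size=2, sequence_length=30, dummy_max_token_id=31999
#     )
#     exported = torch.export.export(
#         model,
#         args=(),
#         kwargs=data["inputs"],
#         dynamic_shapes=data["dynamic_shapes"],
#     )
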
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
    """
    Builds the keyword arguments consumed by ``get_inputs``.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(config, "vocab_size")
    kwargs = dict(
        batch_size=2,
        sequence_length=30,
        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
    )
    return kwargs, get_inputs
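
# Usage sketch: the pair is meant to be chained; ``model`` and ``config`` are
# assumed to come from the caller, and ``config=None`` falls back to the
# defaults above (batch_size=2, sequence_length=30, dummy_max_token_id=31999):
#
#     kwargs, fct = random_input_kwargs(config, __TASK__)
#     data = fct(model, config, **kwargs)
#     outputs = model(**data["inputs"])  # runs the dummy batch through the model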