Source code for onnx_diagnostic.torch_models.untrained.llm_tiny_llm
from typing import Any, Dict

import transformers

def get_tiny_llm(
batch_size: int = 2,
sequence_length: int = 30,
sequence_length2: int = 3,
dynamic_rope: bool = False,
use_static_cache: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""
    Gets a non-initialized model similar to :epkg:`arnir0/Tiny-LLM`.

    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param use_static_cache: use ``StaticCache`` instead of ``DynamicCache``
    :param kwargs: values used to overwrite the configuration,
        e.g. ``num_hidden_layers=1``
    :return: dictionary with keys ``inputs``, ``model``, ``dynamic_shapes``
        and ``configuration``

    See :ref:`l-plot-tiny-llm-export` or :ref:`l-plot-tiny-llm-export-patched`
    for examples.
"""
from ...tasks.text_generation import get_inputs
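    # Tiny Llama configuration: a single hidden layer, two attention heads and
    # a 192-dimensional hidden state keep the untrained model small and fast.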
config = {
"architectures": ["LlamaForCausalLM"],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 192,
"initializer_range": 0.02,
"intermediate_size": 1024,
"max_position_embeddings": 1024,
"model_type": "llama",
"num_attention_heads": 2,
"num_hidden_layers": 1,
"num_key_value_heads": 1,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
"tie_word_embeddings": False,
"torch_dtype": "float32",
"transformers_version": "4.31.0.dev0",
"use_cache": True,
"vocab_size": 32000,
}
config.update(**kwargs)
conf = transformers.LlamaConfig(**config)
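    # Advertise the static cache on the configuration; ``get_inputs`` below is
    # told to build the same cache class (``cls_cache="StaticCache"``).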
if use_static_cache:
conf.cache_implementation = "static"
model = transformers.LlamaForCausalLM(conf)
model.eval()
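    # Build dummy inputs and the matching dynamic shapes expected at export time.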
res = get_inputs(
model,
conf,
dummy_max_token_id=config["vocab_size"], # type: ignore[arg-type]
num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type]
batch_size=batch_size,
sequence_length=sequence_length,
sequence_length2=sequence_length2,
dynamic_rope=dynamic_rope,
num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type]
cls_cache="StaticCache" if use_static_cache else "DynamicCache",
)
return dict(
inputs=res["inputs"],
model=model,
dynamic_shapes=res["dynamic_shapes"],
configuration=conf,
)
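
# Usage sketch (not part of the original module, shown as an assumption): build
# the untrained model, run one forward pass on the generated dummy inputs, then
# attempt ``torch.export`` with the returned dynamic shapes. ``inputs`` is
# assumed to be a kwargs dictionary as produced by ``get_inputs``.
if __name__ == "__main__":
    import torch

    data = get_tiny_llm(num_hidden_layers=1)
    model, inputs = data["model"], data["inputs"]
    with torch.no_grad():
        # One forward pass with the dummy inputs (random, untrained weights).
        outputs = model(**inputs)
    print(type(outputs))

    # Exporting is the intended use case; plain torch.export may require the
    # patches mentioned in the docstring (l-plot-tiny-llm-export-patched) when
    # cache classes are involved, hence the guard.
    try:
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
        )
        print(ep)
    except Exception as exc:
        print("torch.export failed:", exc)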