As AI models grow in complexity, efficiently running small language models (SLMs) is increasingly important for developers, especially in resource-constrained environments. In a previous blog post, we explained that Esperanto Technologies is interested in running SLMs in the ONNX format. These models are first converted from PyTorch or TensorFlow to ONNX, and you can explore open-source model conversions on Hugging Face.
To execute the models, we use ONNXRuntime, an open-source framework for optimizing and running machine learning models that has become a key player in this space. Recently, Esperanto Technologies introduced a custom execution backend for ONNXRuntime, enabling developers to leverage the ET-SoC-1 hardware accelerator through the ONNXRuntime APIs (C, C++, Python, etc.). This integration makes it easy to compile and execute models efficiently, and it is a great option for running small language models.
In this blog, we will explore how to use Esperanto’s EtGlowExecutionProvider and dive into the specific provider options for compiling and running models on Esperanto’s hardware. We will also walk through an example implementation of running an SLM with a Key-Value Cache (KVC), showing how simple it is to switch between running on a CPU and running on Esperanto’s accelerator.
Running Models with Esperanto's ONNXRuntime Provider
To run models with Esperanto’s backend, the key step is to initialize the ONNXRuntime inference session with the EtGlowExecutionProvider. This allows the model to be compiled and executed on Esperanto’s hardware accelerators, providing both performance and power efficiency.
Here’s a basic example of how to initialize an ONNXRuntime inference session with Esperanto’s provider:
```python
import onnxruntime as ort

# Define provider options to configure the accelerator
provider_options = {
    "etglow_onnx_shape_params": "batch=2;sequence_len=256",
    "etglow_greedy": "true"
}

# Initialize the ONNX Runtime session with EtGlowExecutionProvider
session = ort.InferenceSession(
    "model.onnx",
    providers=["EtGlowExecutionProvider"],
    provider_options=[provider_options]
)

# Run inference with the accelerator
input_data = {"input": some_input}  # replace with your model's actual input name and data
output = session.run(None, input_data)
```
Key Provider Options
Esperanto Technologies’ EtGlowExecutionProvider offers several provider options that allow developers to control how models are compiled and executed. These options help ensure that models are compiled efficiently, especially when dealing with dynamic shapes or fine-tuning compiler behavior. Let’s take a closer look:
1. etglow_onnx_shape_params: Handling Parametric Shapes
If your model has dynamic shapes, such as varying batch sizes or sequence lengths, you need to provide the etglow_onnx_shape_params option. This is crucial because the Esperanto compiler performs model compilation at the time of InferenceSession construction, and all parametric shapes must be resolved before compilation can begin.
The shape parameters are specified as key-value pairs, separated by semicolons. For example:
provider_options = { "etglow_onnx_shape_params": "batch=2;sequence_len=256" }
This ensures that the model is compiled with the provided batch size (2) and sequence length (256).
2. etglow_api_params: Fine-Tuning Compiler Behavior
The etglow_api_params option lets developers adjust internal compiler settings to fine-tune its behavior. This includes setting the number of threads used for compilation and specifying the directory for temporary compilation files, among others. For example:
provider_options = { "etglow_api_params": "glow-threads=2;runDir=/tmp/compilations" }
- glow-threads=2 sets the compiler to use two threads during compilation.
- runDir=/tmp/compilations directs temporary files to a specific directory for management.
3. etglow_greedy: Full Graph Ingestion
By default, the EtGlowExecutionProvider inspects the ONNX model graph to determine which operators are supported by Esperanto’s compiler. ONNXRuntime then partitions the graph, assigning supported nodes to the accelerator and unsupported nodes to other execution providers (or the CPU). However, this heuristic inspection is not always perfect, so the etglow_greedy option allows you to bypass it and send the entire model to Esperanto’s compiler, avoiding splitting the model execution between the host and the device.
provider_options = { "etglow_greedy": "true" }
This forces the entire model to be compiled and executed on Esperanto’s hardware, ensuring full graph ingestion.
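When experimenting with these options, it can be useful to confirm which execution providers a session actually registered (for example, whether the CPU provider is present as a fallback). The snippet below is a generic illustration using the standard ONNX Runtime get_providers() call; the model path is a placeholder:

```python
import onnxruntime as ort

# Placeholder model path; the CPU provider is listed as a fallback.
session = ort.InferenceSession(
    "model.onnx",
    providers=["EtGlowExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"etglow_greedy": "true"}, {}],
)

# Prints the execution providers registered for this session, in priority order
print(session.get_providers())
```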
Example: Running a Small Language Model with Key-Value Cache (KVC)
To demonstrate how to run an SLM with KVC using ONNXRuntime, let’s look at a simple Python implementation. The function shown below runs a small language model with a Key-Value Cache, allowing the model to reuse previous computations; this is critical for optimizing sequential inference such as text generation.
```python
import numpy as np
import onnxruntime
from transformers import AutoTokenizer

# `update_next_index` and `softmax` are helper functions defined elsewhere in the implementation.

def llm_kvc_inference(session: onnxruntime.InferenceSession, tokenizer: AutoTokenizer, input_tensors: dict,
                      prompt_tensor, num_tokens, context, sequence_len, window: int, batch: int) -> tuple[str, float]:
    sum_perplexity = 0
    new_token = np.array([10])
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window

    inputs_names = [input.name for input in session.get_inputs()]
    output_names = [output.name for output in session.get_outputs()]

    # Run the inferences
    while next_index < num_tokens:
        if (new_token == tokenizer.eos_token_id).any():
            break

        output = session.run(output_names, input_tensors)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}

        logits = outs_dictionary['logits']

        # Prepare next inference inputs
        for name in inputs_names:
            if name == 'input_ids':
                current_index = next_index
                next_index = update_next_index(next_index, prompt_size, window)
                j = next_index - window

                if current_index >= prompt_size:
                    top1 = logits.argmax(-1)
                    new_token = top1.reshape(batch, window)  # Infer next token
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)

                input_tensors['input_ids'] = total_input[:, j:next_index].reshape(batch, window)
            elif name == 'attention_mask':
                attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                attention_mask[:, -next_index:] = 1
                input_tensors['attention_mask'] = attention_mask
            else:
                # Feed the 'present' KV-cache outputs back as 'past_key_values' inputs
                old_name = name.replace("past_key_values", "present")
                start_idx = next_index - current_index
                end_idx = start_idx + context - window
                input_tensors[name] = outs_dictionary[old_name][:, :, start_idx:end_idx, :]

        # Accumulate the negative log-likelihood used to compute perplexity
        sum_perplexity += -np.log(softmax(logits[0, 0])[input_tensors['input_ids'][0]])

    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return answers, sum_perplexity
```
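Calling this function requires an initial set of input tensors: the prompt tokens in input_ids, an attention_mask, and zero-filled past_key_values entries for every KVC input the model exposes. The sketch below is illustrative only; it assumes session and tokenizer objects created as in the previous snippets, and the tensor names, sizes (num_heads, head_dim), and dtypes are placeholders that must match your exported model:

```python
import numpy as np

# Assumed to exist already (see earlier snippets):
#   session   - onnxruntime.InferenceSession created with the EtGlowExecutionProvider
#   tokenizer - transformers.AutoTokenizer matching the exported model
# All sizes below are illustrative placeholders.
batch, window, context, sequence_len = 1, 1, 2048, 2048
num_heads, head_dim = 8, 128

prompt_tensor = tokenizer("when summer starts?", return_tensors="np").input_ids

input_tensors = {}
for inp in session.get_inputs():
    if inp.name == "input_ids":
        # Start with the first `window` prompt tokens
        input_tensors[inp.name] = prompt_tensor[:, :window].reshape(batch, window)
    elif inp.name == "attention_mask":
        attention_mask = np.zeros((batch, sequence_len), dtype="int64")
        attention_mask[:, -window:] = 1
        input_tensors[inp.name] = attention_mask
    else:
        # Zero-initialized KV-cache inputs (e.g. "past_key_values.0.key");
        # dtype and layout must match the exported model.
        input_tensors[inp.name] = np.zeros(
            (batch, num_heads, context - window, head_dim), dtype="float16"
        )

answers, perplexity = llm_kvc_inference(
    session, tokenizer, input_tensors, prompt_tensor,
    num_tokens=120, context=context, sequence_len=sequence_len,
    window=window, batch=batch,
)
print(answers[0])
```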
To learn more about KVC models, please read our previous blog post.
Switching Between CPU and ET-SoC-1
Running ONNX models on a CPU or on Esperanto’s accelerator (ET-SoC-1) requires only a small change in the initialization of the InferenceSession. For CPU execution, the session is initialized like this:
```python
session_cpu = onnxruntime.InferenceSession(
    model_path,
    sess_options,
    providers=['CPUExecutionProvider']
)
```
For execution on the Esperanto accelerator, use:
```python
provider_options = get_etglow_provider_options()
session_etglow = onnxruntime.InferenceSession(
    model_path,
    sess_options,
    providers=['EtGlowExecutionProvider'],
    provider_options=[provider_options]
)
```
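The snippet above calls a get_etglow_provider_options() helper whose body is not shown. A minimal version, reusing the example option values from earlier in this post, could look like the following (adjust the shape parameters and API parameters to your model and environment):

```python
def get_etglow_provider_options() -> dict:
    # Example values taken from the options discussed above;
    # adapt batch/sequence_len and api_params to your model and setup.
    return {
        "etglow_onnx_shape_params": "batch=2;sequence_len=256",
        "etglow_api_params": "glow-threads=2;runDir=/tmp/compilations",
        "etglow_greedy": "true",
    }
```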
With just a small tweak, you can switch between running the model on the CPU or on the ET-SoC-1 accelerator, offering flexibility depending on your hardware and performance needs.
Results
By using ONNXRuntime, we can easily make apples-to-apples comparisons between the CPU Execution Provider and EtGlow Execution Provider.
Model: Esperanto/llama3-8b-Instruct-kvc-AWQ-int4-onnx
Number of tokens requested: 120
Test question: “when summer starts?”
CPU (Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz):
- Inference Time: 159.64 seconds (0.63 tokens/s)
ET-SoC-1 Naive implementation:
- Inference Time: 19.19 seconds (5.21 tokens/s)
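One way to obtain numbers like these is to wrap the generation call in a wall-clock timer and divide the requested token count by the elapsed time. The sketch below assumes the llm_kvc_inference function, the session_etglow session, and the input_tensors / prompt_tensor prepared in the previous sections:

```python
import time

num_tokens = 120  # Number of tokens requested, as in the results above

start = time.perf_counter()
answers, perplexity = llm_kvc_inference(
    session_etglow, tokenizer, input_tensors, prompt_tensor,
    num_tokens, context, sequence_len, window, batch,
)
elapsed = time.perf_counter() - start

# Report total inference time and throughput in tokens per second
print(f"Inference time: {elapsed:.2f} s ({num_tokens / elapsed:.2f} tokens/s)")
```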
Conclusion
Esperanto Technologies’ custom execution backend for ONNXRuntime provides a simple way to accelerate small language models on the ET-SoC-1 accelerator. The combination of flexible provider options like etglow_onnx_shape_params, etglow_api_params, and etglow_greedy offers developers control over the model compilation process, ensuring efficient execution.
By following the steps in this guide, you can easily run SLMs, switching between CPU execution and Esperanto’s hardware as needed. This flexibility unlocks new possibilities for deploying high-performance AI models in a wide range of environments.
By utilizing the EtGlowExecutionProvider, users can achieve remarkable speedups compared to CPU execution (for instance, 8.26x for the current example), and this is with the naive implementation.
In future blog posts, we will explain how to capture traces with ONNXRuntime and the EtGlowExecutionProvider to visualize the performance of this implementation. Additionally, we will explore optimizations and techniques to further enhance this performance.