In a previous blog post, we demonstrated a streamlined implementation for running small language models (SLMs) with Key-Value Cache (KVC) on the ET-SoC-1 accelerator using ONNXRuntime. That implementation required only about 50 lines of code, showcasing how simple and efficient inference can be with the right tools and optimizations.
While the code is concise, a closer examination reveals room for optimization. Specifically, we identified inefficiencies during the preparation of inputs for the next token generation that lead to excessive host-device communication. This post dives deeper into these inefficiencies and presents a better approach using the ONNXRuntime IO-Bindings feature.
Revisiting the Initial Implementation
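Before revisiting the code, note that the snippets in this post are not fully self-contained: helpers such as update_next_index, softmax, and compute_perplexity come from the accompanying script, and the imports are not shown. As a rough sketch, the surrounding imports presumably look something like this (the TypeHelper import path is an assumption based on ONNXRuntime's transformer helpers):

```python
# Assumed imports for the snippets in this post (the original script may differ).
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import onnxruntime
from onnxruntime.transformers.io_binding_helper import TypeHelper  # assumed source of TypeHelper
from transformers import AutoTokenizer
```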
```python
def llm_kvc_inference(session: onnxruntime.InferenceSession, tokenizer: AutoTokenizer,
                      input_tensors: dict, prompt_tensor, num_tokens, context, sequence_len,
                      window: int, batch: int) -> tuple[str, float]:
    sum_perplexity = 0
    new_token = np.array([10])
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window

    inputs_names = [input.name for input in session.get_inputs()]
    output_names = [output.name for output in session.get_outputs()]

    # Run the inferences
    while next_index < num_tokens:
        if (new_token == tokenizer.eos_token_id).any():
            break

        output = session.run(output_names, input_tensors)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}

        logits = outs_dictionary['logits']

        # Prepare next inference inputs
        for name in inputs_names:
            if name == 'input_ids':
                current_index = next_index
                next_index = update_next_index(next_index, prompt_size, window)
                j = next_index - window

                if current_index >= prompt_size:
                    top1 = logits.argmax(-1)
                    new_token = top1.reshape(batch, window)  # Infer next token
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)

                input_tensors['input_ids'] = total_input[:, j:next_index].reshape(batch, window)
            elif name == 'attention_mask':
                attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                attention_mask[:, -next_index:] = 1
                input_tensors['attention_mask'] = attention_mask
            else:
                old_name = name.replace("past_key_values", "present")
                start_idx = next_index - current_index
                end_idx = start_idx + context - window
                input_tensors[name] = outs_dictionary[old_name][:, :, start_idx:end_idx, :]

        sum_perplexity += -np.log(softmax(logits[0, 0])[input_tensors['input_ids'][0]])

    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity
```
Here’s a quick breakdown of the initial implementation:
- Inference Execution: The session.run call executes the model with ONNXRuntime and produces the logits used to generate the next token.
- Next Token Input Preparation: The rest of the loop body prepares the inputs for the next token generation. This preparation involves updating the KVC tensors (“past”) with the newly generated “present” values.
Since the input preparation is performed on the host, ONNXRuntime needs to:
- Transfer the “present” tensors from the device to the host.
- Allow the host to prepare the new “past” tensors.
- Transfer the updated “past” tensors back to the device.
This process introduces extra latency due to multiple device-host-device memory transfers (one per model layer), significantly impacting performance.

Figure 1: Extra Overhead Added by Performing Next-Inference Preparations on the Host
In Figure 1, we can see that, on average, each inference spends approximately 103ms on next-token input preparation.
A Better Way with IO-Bindings
To overcome these inefficiencies, we can leverage ONNXRuntime’s IO-Bindings feature, available in upstream ONNXRuntime. This powerful feature allows developers to bind input and output tensors directly to device memory, eliminating unnecessary data transfers between the host and the device.
Here is the key insight: the “present” KVC tensors produced by one inference step can be reused as the “past” tensors of the next one. By pre-allocating memory for the “past” and “present” tensors with sufficient padding, we can implement a sliding window mechanism: instead of copying data between iterations, we simply update the pointers into the pre-allocated tensors, which makes the process highly efficient.
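To make the mechanics concrete before diving into the KVC-specific code, here is a minimal, model-agnostic sketch of the IO-Bindings pattern. The model path, tensor names, and shapes are placeholders, and it uses the CPU allocator so it can be read standalone; on the ET-SoC-1, the session and OrtValues would live on the device (device_type="et") instead.

```python
import numpy as np
import onnxruntime

# Minimal sketch of the IO-Bindings pattern, independent of the KVC logic.
# "model.onnx", the tensor names, and the shapes are placeholders.
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Allocate the input and output buffers once, directly as OrtValues.
x = onnxruntime.OrtValue.ortvalue_from_numpy(
    np.zeros((1, 16), dtype=np.int64), "cpu", 0)
y = onnxruntime.OrtValue.ortvalue_from_shape_and_type(
    (1, 16, 32000), np.float16, "cpu", 0)

# Bind both tensors once; subsequent runs read and write these buffers in place,
# so no per-inference host<->device copies are required.
io_binding = session.io_binding()
io_binding.bind_ortvalue_input("input_ids", x)
io_binding.bind_ortvalue_output("logits", y)

session.run_with_iobinding(io_binding)
logits = y.numpy()  # copy back to the host only when the result is actually needed
```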
Implementing IO-Bindings
Below is a simplified code snippet demonstrating how to implement this optimization. First, we introduce the two new helper methods required by the new llm_kvc_inference_iobindings implementation.

To implement the sliding window mechanism, the past and present tensors must be bound to the same underlying memory allocation. This is achieved in preallocate_input_outputs by creating an OrtValue that is large enough to hold both.
```python
def preallocate_input_outputs(session: onnxruntime.InferenceSession, output_names: list[str],
                              input_tensors: dict, window: int, batch: int, nheads: int,
                              context: int, hidden: int, logits_last_dim: int,
                              device_type="et", device_id=0) -> dict:
    ortvalues = {
        'input_ids': onnxruntime.OrtValue.ortvalue_from_numpy(
            input_tensors['input_ids'], device_type, device_id),
        'attention_mask': onnxruntime.OrtValue.ortvalue_from_numpy(
            input_tensors['attention_mask'], device_type, device_id),
        'logits': onnxruntime.OrtValue.ortvalue_from_shape_and_type(
            (batch, window, logits_last_dim),
            TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'logits')),
            device_type, device_id),
    }

    # All KVC input (past_key_values) & output (present) tensors will share the same underlying allocation.
    zeros = np.zeros((batch, nheads, context, hidden),
                     dtype=TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'present.0.key')))
    zeros_padded = np.pad(zeros, ((0, 0), (0, 1), (0, 0), (0, 0)), mode='constant')
    for name in output_names:
        if 'present' in name:
            ortvalues[name] = onnxruntime.OrtValue.ortvalue_from_numpy(zeros_padded, device_type, device_id)
    return ortvalues
```
The past/present tensors are then bound by passing the appropriate device pointer (the bind_input and bind_output calls that take a buffer_ptr below):
```python
def bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues,
                       batch: int, nheads: int, context: int, hidden: int):
    # Bind inputs (will bind them to the allocated memory)
    for name in inputs_names:
        if name in ['input_ids', 'attention_mask']:
            # For input_ids or attention_mask let's bind the ortvalue directly
            io_binding.bind_ortvalue_input(name, ortvalues[name])
        else:
            # For 'past_key_values' we need to bind the buffer_ptr to the underlying allocation
            out_name = name.replace("past_key_values", "present")
            io_binding.bind_input(name,
                                  device_type=ortvalues[out_name].device_name(),
                                  device_id=0,
                                  element_type=np.float16,
                                  shape=(batch, nheads, context - 1, hidden),
                                  buffer_ptr=ortvalues[out_name].data_ptr())

    # Bind outputs
    for name in outputs_names:
        if 'logits' in name:
            io_binding.bind_ortvalue_output(name, ortvalues[name])
        else:
            io_binding.bind_output(name,
                                   device_type=ortvalues[name].device_name(),
                                   device_id=0,
                                   element_type=np.float16,
                                   shape=(batch, nheads, context, hidden),
                                   buffer_ptr=ortvalues[name].data_ptr())
```
Next comes the optimized version, llm_kvc_inference_iobindings. Notice how all inputs and outputs are pre-allocated with preallocate_input_outputs, then bound to the device with bind_input_outputs, and how each inference is launched with session.run_with_iobinding instead of session.run. Also notice how the next-token input preparation uses OrtValue update_inplace for input_ids and attention_mask, while the new update_kvc_view method handles the remaining (KVC) inputs.
```python
def llm_kvc_inference_iobindings(session: onnxruntime.InferenceSession,
                                 run_options: onnxruntime.RunOptions,
                                 tokenizer: AutoTokenizer, input_tensors: dict,
                                 prompt_tensor, num_tokens, context, sequence_len,
                                 window: int, batch: int) -> tuple[str, float]:
    sum_perplexity = 0
    new_token = np.array([10])
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window

    logits_last_dim = int(session.get_outputs()[0].shape[-1])
    hidden = int(session.get_outputs()[1].shape[-1])
    nheads = int(session.get_outputs()[1].shape[-3])

    inputs_names = [input.name for input in session.get_inputs()]
    outputs_names = [output.name for output in session.get_outputs()]

    # Pre-allocate inputs
    ortvalues = preallocate_input_outputs(session, outputs_names, input_tensors,
                                          window, batch, nheads, context, hidden, logits_last_dim)
    # Create IOBindings
    io_binding = session.io_binding()
    bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues,
                       batch, nheads, context, hidden)

    with ThreadPoolExecutor() as executor:
        futures = []

        # Run the inferences
        while next_index < num_tokens:
            if (new_token == tokenizer.eos_token_id).any():
                break

            session.run_with_iobinding(io_binding, run_options)

            logits = ortvalues['logits'].numpy()

            # Prepare next inference inputs
            for name in inputs_names:
                if name == 'input_ids':
                    current_index = next_index
                    next_index = update_next_index(next_index, prompt_size, window)
                    j = next_index - window

                    if current_index >= prompt_size:
                        top1 = logits.argmax(-1)
                        new_token = top1.reshape(batch, window)  # Infer next token
                        total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)

                    next_input_ids = total_input[:, j:next_index]
                    ortvalues['input_ids'].update_inplace(np.ascontiguousarray(next_input_ids))
                elif name == 'attention_mask':
                    attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                    attention_mask[:, -next_index:] = 1
                    ortvalues['attention_mask'].update_inplace(attention_mask)
                else:
                    update_kvc_view(io_binding, name, ortvalues, current_index,
                                    batch, nheads, context, hidden)

            # Offload perplexity calculation to a separate thread
            futures.append(executor.submit(compute_perplexity, logits[0, 0], next_input_ids[0]))

        # Accumulate perplexity results from helper thread
        for future in as_completed(futures):
            sum_perplexity += future.result()

    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity
```
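The compute_perplexity helper offloaded to the thread pool is not shown here; judging from the accumulation in the naive version, it presumably looks something like the following sketch (softmax being the same helper used there):

```python
def compute_perplexity(logits_vec, token_ids):
    # Negative log-probability of the selected token(s), mirroring the
    # sum_perplexity accumulation in the naive implementation.
    return -np.log(softmax(logits_vec)[token_ids])
```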
Finally, in the update_kvc_view implementation, we can see how bind_input and bind_output are used to update the device pointers to the past/present tensors.
```python
def update_kvc_view(io_binding, name, ortvalues, current_index,
                    batch: int, nheads, context, hidden):
    out_name = name.replace("past_key_values", "present")
    last_dim_value = hidden
    element_size_bytes = 2
    num_bytes_per_cache_line = last_dim_value * element_size_bytes
    offset_bytes = current_index * num_bytes_per_cache_line
    buffer_ptr = ortvalues[out_name].data_ptr() + offset_bytes
    device_type = ortvalues[out_name].device_name()
    io_binding.bind_input(name,
                          device_type=device_type,
                          device_id=0,
                          element_type=np.float16,
                          shape=(batch, nheads, context - 1, hidden),
                          buffer_ptr=buffer_ptr)
    io_binding.bind_output(out_name,
                           device_type=device_type,
                           device_id=0,
                           element_type=np.float16,
                           shape=(batch, nheads, context, hidden),
                           buffer_ptr=buffer_ptr)
```
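To make the pointer arithmetic concrete, here is a small worked example with illustrative values (the hidden size and index are made up for the example):

```python
# Illustrative values only; hidden and current_index are hypothetical.
hidden = 128             # last dimension of each KV tensor
element_size_bytes = 2   # float16
current_index = 16       # tokens already materialized in the cache

offset_bytes = current_index * hidden * element_size_bytes
print(offset_bytes)  # 4096: the "past" binding starts 16 rows into the shared "present" buffer
```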
Results
By implementing IO-Bindings and pre-allocating memory for KVC tensors, we observed significant performance improvements compared to the naive implementation:
Model: Esperanto/llama3-8b-Instruct-kvc-AWQ-int4-onnx (Hugging Face)
Number of tokens requested: 120
Test question: “when summer starts?”
Naive Implementation Results
- Inference Time: 19.19 seconds (5.21 tokens/s)
IO-Bindings Implementation Results
- Inference Time: 8.68 seconds (11.51 tokens/s)
In addition, we can compare the execution traces of the naive and IO-Bindings implementations.

Figure 2: Naive vs IO-Bindings Implementation
As shown in Figures 2 and 3, the IO-Bindings implementation completely eliminated the extra overhead of handling the past and present KVC tensors. In Figure 3, we can see a significant reduction in DKernel back-to-back latency after applying IO-Bindings.

Figure 3: DKernel Back-to-Back Latency
Furthermore, when measuring token generation latency, understood as the time between the completions of two consecutive DKernels (Figure 4), we observe an improvement from approximately 185ms to 76ms.

Figure 4: Token Generation Latency From Device
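As a quick sanity check, the ratios quoted in the conclusion follow directly from the numbers reported above:

```python
# Ratios derived from the figures reported in this post.
print(f"throughput:           {11.51 / 5.21:.2f}x")  # ~2.21x (tokens/s)
print(f"end-to-end time:      {19.19 / 8.68:.2f}x")  # ~2.21x (total inference time)
print(f"device token latency: {185 / 76:.2f}x")      # ~2.43x (185ms -> 76ms)
```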
Conclusion
IO-Bindings unlock the full potential of ONNXRuntime by minimizing data-transfer overhead and optimizing memory usage. When running small language models with KVC on accelerators, these techniques are essential for achieving high-performance inference. The results speak for themselves: inference speed improves from 5.21 tokens/s to 11.51 tokens/s, a 2.2x gain in throughput, while device-measured token generation latency drops from roughly 185ms to 76ms, about a 2.4x improvement.