In a previous blog post, we demonstrated a streamlined implementation for running small language models (SLMs) with Key-Value Cache (KVC) on the ET-SoC-1 accelerator using ONNXRuntime. That implementation required only about 50 lines of code, showcasing how simple and efficient inference can be with the right tools and optimizations.
While the code is concise, a closer examination reveals room for optimization. Specifically, we identified inefficiencies during the preparation of inputs for the next token generation that lead to excessive host-device communication. This post dives deeper into these inefficiencies and presents a better approach using the ONNXRuntime IO-Bindings feature.
Revisiting the Initial Implementation
```python
import numpy as np
import onnxruntime
from scipy.special import softmax  # any softmax over a 1-D logits vector works here
from transformers import AutoTokenizer

# update_next_index() is a helper not defined in this snippet.


def llm_kvc_inference(session: onnxruntime.InferenceSession, tokenizer: AutoTokenizer, input_tensors: dict,
                      prompt_tensor, num_tokens, context, sequence_len, window: int, batch: int) -> tuple[list[str], float]:
    sum_perplexity = 0
    new_token = np.array([10])
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window
    inputs_names = [input.name for input in session.get_inputs()]
    output_names = [output.name for output in session.get_outputs()]
    # Run the inferences
    while next_index < num_tokens:
        if (new_token == tokenizer.eos_token_id).any():
            break
        output = session.run(output_names, input_tensors)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}
        logits = outs_dictionary['logits']
        # Prepare next inference inputs (all of this runs on the host)
        for name in inputs_names:
            if name == 'input_ids':
                current_index = next_index
                next_index = update_next_index(next_index, prompt_size, window)
                j = next_index - window
                if current_index >= prompt_size:
                    top1 = logits.argmax(-1)
                    new_token = top1.reshape(batch, window)  # Infer next token
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
                input_tensors['input_ids'] = total_input[:, j:next_index].reshape(batch, window)
            elif name == 'attention_mask':
                attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                attention_mask[:, -next_index:] = 1
                input_tensors['attention_mask'] = attention_mask
            else:
                # past_key_values.* <- slice of the matching present.* output (host-side copy)
                old_name = name.replace("past_key_values", "present")
                start_idx = next_index - current_index
                end_idx = start_idx + context - window
                input_tensors[name] = outs_dictionary[old_name][:, :, start_idx:end_idx, :]
        sum_perplexity += -np.log(softmax(logits[0, 0])[input_tensors['input_ids'][0]])
    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity
```
Here’s a quick breakdown of the initial implementation:
- Inference Execution: the session.run call runs the model with ONNXRuntime to generate the next token.
- Next Token Input Preparation: the rest of the loop body prepares the inputs for the next token generation. This includes updating the KVC tensors (“past”) with the newly generated “present” values.
Since the input preparation is performed on the host, ONNXRuntime needs to:
- Transfer the “present” tensors from the device to the host.
- Let the host prepare the new “past” tensors.
- Transfer the updated “past” tensors back to the device.
These repeated device-host-device memory transfers (one round trip per model layer) add latency to every generated token, which significantly impacts performance.
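To get a feel for how much data this round trip moves, here is a back-of-the-envelope sketch. It assumes one key and one value tensor per layer, a “present” shape of (batch, nheads, context, hidden), a “past” shape of (batch, nheads, context - window, hidden) as in the slicing of the naive loop, and 2-byte (float16) elements. The function name and the example shapes (Llama-3-8B-like: 32 layers, 8 KV heads, head size 128, with an assumed context of 1024) are illustrative, not measurements.

```python
def naive_kvc_transfer_cost(num_layers: int, batch: int, nheads: int, context: int,
                            hidden: int, window: int = 1, itemsize: int = 2) -> tuple[int, int]:
    """Return (number of copies, bytes copied) per decoding step for the naive loop."""
    kvc_tensors = 2 * num_layers  # one key and one value tensor per layer
    present_bytes = batch * nheads * context * hidden * itemsize          # device -> host
    past_bytes = batch * nheads * (context - window) * hidden * itemsize  # host -> device
    copies = 2 * kvc_tensors  # every KVC tensor crosses the link twice per step
    return copies, kvc_tensors * (present_bytes + past_bytes)


# Illustrative shapes only (assumed, not taken from the benchmark).
copies, nbytes = naive_kvc_transfer_cost(num_layers=32, batch=1, nheads=8, context=1024, hidden=128)
print(copies, nbytes / 2**20)  # prints: 128 255.875  (copies and MiB per step)
```

The exact figures depend on the model and context length; the point is that the traffic scales with the full cache size on every single step.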

Figure 1: Extra Overhead Added by Performing Next-Inference Preparations on the Host
Figure 1 shows that, on average, each inference spends approximately 103 ms preparing the inputs for the next token.
A Better Way with IO-Bindings
To overcome these inefficiencies, we can leverage ONNXRuntime’s IO-Bindings feature, available upstream. This feature lets developers bind input and output tensors directly to device memory, eliminating unnecessary data transfers between the host and the device.
Here is the key insight: After an inference step, the “present” KVC tensors generated can be reused for subsequent iterations. By pre-allocating memory for the “past” and “present” tensors with sufficient padding, we can seamlessly implement a sliding window mechanism. This enables us to update pointers to the tensors without additional memory transfers, making the process highly efficient.
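The sliding-window idea can be simulated with plain NumPy views. This simplified sketch assumes a single head, a batch of one, one new token per step, and a model that produces “present” as “past” followed by the newest position. One padded buffer backs both views: each step writes only the newest row, and the next “past” is obtained by moving the start of the view forward by one row, which is, in simplified form, the pointer shift the IO-Bindings version performs on the device.

```python
import numpy as np


def simulate_sliding_kvc(context: int, hidden: int, steps: int) -> np.ndarray:
    """Single head, batch 1, one new token per step (simplifying assumptions)."""
    # One backing buffer with `steps` spare rows so the window can slide forward.
    buf = np.zeros((context + steps, hidden), dtype=np.float16)
    for t in range(steps):
        past = buf[t:t + context - 1]  # view bound as the "past" input (no copy)
        present = buf[t:t + context]   # view bound as the "present" output, same memory
        assert np.shares_memory(past, present)
        # present = concat(past, new_row): its first context - 1 rows already hold
        # past, so only the newest row actually has to be written.
        present[-1] = t + 1            # stand-in for the new key/value row
    # The next step's past is the same buffer, shifted by one more row.
    return buf[steps:steps + context - 1]


print(simulate_sliding_kvc(context=4, hidden=2, steps=5)[:, 0])  # [3. 4. 5.]
```

No history is ever copied: the window simply advances through memory that was already written in place.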
Implementing IO-Bindings
Below is a simplified code snippet demonstrating how to implement this optimization. First, we introduce two helper functions required by the new llm_kvc_inference_iobindings implementation.
To implement a sliding window mechanism, the present and past tensors must be bound to the same underlying memory allocation. preallocate_input_outputs achieves this by creating, for each present tensor, a padded OrtValue large enough for the window to slide through.
```python
import numpy as np
import onnxruntime
from onnxruntime.transformers.io_binding_helper import TypeHelper


def preallocate_input_outputs(session: onnxruntime.InferenceSession, output_names: list[str], input_tensors: dict,
                              window: int, batch: int, nheads: int, context: int, hidden: int, logits_last_dim: int,
                              device_type="et", device_id=0) -> dict:
    ortvalues = {
        'input_ids': onnxruntime.OrtValue.ortvalue_from_numpy(input_tensors['input_ids'], device_type, device_id),
        'attention_mask': onnxruntime.OrtValue.ortvalue_from_numpy(input_tensors['attention_mask'], device_type, device_id),
        'logits': onnxruntime.OrtValue.ortvalue_from_shape_and_type(
            (batch, window, logits_last_dim),
            TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'logits')),
            device_type, device_id),
    }
    # All KVC input (past_key_values) & output (present) tensors will share the same underlying allocation.
    kvc_dtype = TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'present.0.key'))
    zeros = np.zeros((batch, nheads, context, hidden), dtype=kvc_dtype)
    # Pad the heads axis by one slot: the extra space lets the bound views slide forward.
    zeros_padded = np.pad(zeros, ((0, 0), (0, 1), (0, 0), (0, 0)), mode='constant')
    for name in output_names:
        if 'present' in name:
            ortvalues[name] = onnxruntime.OrtValue.ortvalue_from_numpy(zeros_padded, device_type, device_id)
    return ortvalues
```
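One way to reason about the padding is to compute how far the pointer can slide. The hypothetical helper below reproduces the np.pad call from preallocate_input_outputs and, assuming the offset arithmetic of update_kvc_view (current_index * hidden elements from the start of the allocation) and a “present” view of batch * nheads * context * hidden elements, returns the largest current_index that keeps the view inside the allocation.

```python
import numpy as np


def kvc_padding_budget(batch: int, nheads: int, context: int, hidden: int,
                       dtype=np.float16) -> tuple[tuple[int, ...], int, int]:
    """Return (padded shape, allocation size in bytes, max slide in cache lines)."""
    zeros = np.zeros((batch, nheads, context, hidden), dtype=dtype)
    # Same padding as preallocate_input_outputs: one extra slot on the heads axis.
    padded = np.pad(zeros, ((0, 0), (0, 1), (0, 0), (0, 0)), mode='constant')
    present_elems = batch * nheads * context * hidden  # elements covered by the bound view
    # The view starts current_index * hidden elements in, so it stays in bounds while
    # current_index * hidden + present_elems <= padded.size.
    max_index = (padded.size - present_elems) // hidden
    return padded.shape, padded.nbytes, max_index


print(kvc_padding_budget(batch=1, nheads=8, context=128, hidden=64))  # ((1, 9, 128, 64), 147456, 128)
```

In general the slack works out to batch * context cache lines, so with a batch of one the extra head slot lets the window slide by up to context positions.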
Next, bind_input_outputs binds each past/present pair to the device pointer of their shared allocation (see the bind_input and bind_output calls below):
```python
import numpy as np


def bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues, batch: int, nheads: int, context: int, hidden: int):
    # Bind inputs (will bind them to the allocated memory)
    for name in inputs_names:
        if name in ['input_ids', 'attention_mask']:
            # For input_ids and attention_mask, bind the OrtValue directly
            io_binding.bind_ortvalue_input(name, ortvalues[name])
        else:
            # For past_key_values.*, bind the buffer_ptr of the matching present.* allocation
            out_name = name.replace("past_key_values", "present")
            io_binding.bind_input(name,
                                  device_type=ortvalues[out_name].device_name(),
                                  device_id=0,
                                  element_type=np.float16,  # KVC tensors are float16
                                  shape=(batch, nheads, context - 1, hidden),  # past: one position fewer than present
                                  buffer_ptr=ortvalues[out_name].data_ptr())
    # Bind outputs
    for name in outputs_names:
        if 'logits' in name:
            io_binding.bind_ortvalue_output(name, ortvalues[name])
        else:
            io_binding.bind_output(name,
                                   device_type=ortvalues[name].device_name(),
                                   device_id=0,
                                   element_type=np.float16,
                                   shape=(batch, nheads, context, hidden),
                                   buffer_ptr=ortvalues[name].data_ptr())
```
Then we present the optimized version, llm_kvc_inference_iobindings. Notice how all inputs and outputs are pre-allocated on the device by preallocate_input_outputs, then bound with bind_input_outputs, and how each inference is launched by session.run_with_iobinding instead of session.run. Also notice how the next-token input preparation uses OrtValue.update_inplace for input_ids and attention_mask, while the new update_kvc_view function handles the KVC tensors.
```python
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import onnxruntime
from transformers import AutoTokenizer

# update_next_index() and compute_perplexity() are helpers not defined in this snippet.


def llm_kvc_inference_iobindings(session: onnxruntime.InferenceSession, run_options: onnxruntime.RunOptions,
                                 tokenizer: AutoTokenizer, input_tensors: dict,
                                 prompt_tensor, num_tokens, context, sequence_len, window: int, batch: int) -> tuple[list[str], float]:
    sum_perplexity = 0
    new_token = np.array([10])
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window
    # Shapes come from the model outputs: output 0 is logits, output 1 is a present.* tensor
    logits_last_dim = int(session.get_outputs()[0].shape[-1])
    hidden = int(session.get_outputs()[1].shape[-1])
    nheads = int(session.get_outputs()[1].shape[-3])
    inputs_names = [input.name for input in session.get_inputs()]
    outputs_names = [output.name for output in session.get_outputs()]
    # Pre-allocate inputs and outputs on the device
    ortvalues = preallocate_input_outputs(session, outputs_names, input_tensors, window, batch, nheads, context, hidden, logits_last_dim)
    # Create IOBindings
    io_binding = session.io_binding()
    bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues, batch, nheads, context, hidden)
    with ThreadPoolExecutor() as executor:
        futures = []
        # Run the inferences
        while next_index < num_tokens:
            if (new_token == tokenizer.eos_token_id).any():
                break
            session.run_with_iobinding(io_binding, run_options)
            logits = ortvalues['logits'].numpy()  # only the logits are copied to the host
            # Prepare next inference inputs
            for name in inputs_names:
                if name == 'input_ids':
                    current_index = next_index
                    next_index = update_next_index(next_index, prompt_size, window)
                    j = next_index - window
                    if current_index >= prompt_size:
                        top1 = logits.argmax(-1)
                        new_token = top1.reshape(batch, window)  # Infer next token
                        total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
                    next_input_ids = total_input[:, j:next_index]
                    ortvalues['input_ids'].update_inplace(np.ascontiguousarray(next_input_ids))
                elif name == 'attention_mask':
                    attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                    attention_mask[:, -next_index:] = 1
                    ortvalues['attention_mask'].update_inplace(attention_mask)
                else:
                    # No copy: re-point the past/present views into the shared allocation
                    update_kvc_view(io_binding, name, ortvalues, current_index, batch, nheads, context, hidden)
            # Offload perplexity calculation to a separate thread
            futures.append(executor.submit(compute_perplexity, logits[0, 0], next_input_ids[0]))
        # Accumulate perplexity results from the helper thread
        for future in as_completed(futures):
            sum_perplexity += future.result()
    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity
```
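compute_perplexity is not defined in the post. A minimal sketch of a compatible helper, assuming it mirrors the per-step term of the naive implementation (the negative log of the softmax probability of the next input token), could look like this. It computes a numerically stable log-softmax directly so that it needs only NumPy.

```python
import numpy as np


def compute_perplexity(token_logits: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """Hypothetical helper mirroring -np.log(softmax(logits[0, 0])[input_ids[0]])."""
    # Upcast (logits may be float16) and shift by the max for numerical stability.
    shifted = token_logits.astype(np.float64) - token_logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())  # log-softmax over the vocabulary
    return -log_probs[token_ids]
```

Despite the name, the accumulated value divided by num_tokens - 1 is an average negative log-likelihood; exponentiating it gives the conventional perplexity.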
Finally, update_kvc_view uses bind_input and bind_output to move the device pointer of the past/present tensors forward by current_index cache lines within their shared allocation.
```python
import numpy as np


def update_kvc_view(io_binding, name, ortvalues, current_index, batch: int, nheads, context, hidden):
    out_name = name.replace("past_key_values", "present")
    last_dim_value = hidden
    element_size_bytes = 2  # float16
    num_bytes_per_cache_line = last_dim_value * element_size_bytes
    # Slide the window: skip current_index cache lines from the start of the shared allocation
    offset_bytes = current_index * num_bytes_per_cache_line
    buffer_ptr = ortvalues[out_name].data_ptr() + offset_bytes
    device_type = ortvalues[out_name].device_name()
    # Both past and present now start at the shifted pointer
    io_binding.bind_input(name,
                          device_type=device_type,
                          device_id=0,
                          element_type=np.float16,
                          shape=(batch, nheads, context - 1, hidden),
                          buffer_ptr=buffer_ptr)
    io_binding.bind_output(out_name,
                           device_type=device_type,
                           device_id=0,
                           element_type=np.float16,
                           shape=(batch, nheads, context, hidden),
                           buffer_ptr=buffer_ptr)
```
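update_kvc_view hard-codes a 2-byte element size, which matches the float16 KVC used here. If the cache dtype could vary, the same offset can be derived from the dtype instead. The hypothetical helper below does that and cross-checks the arithmetic on the host with NumPy, where slicing rows off a C-contiguous buffer moves the data pointer by exactly the same number of bytes.

```python
import numpy as np


def kvc_offset_bytes(current_index: int, hidden: int, dtype=np.float16) -> int:
    # Same formula as update_kvc_view: current_index cache lines of `hidden` elements each.
    return current_index * hidden * np.dtype(dtype).itemsize


# Host-side cross-check: a C-contiguous (rows, hidden) buffer sliced at row 5.
buf = np.zeros((16, 64), dtype=np.float16)
delta = buf[5:].ctypes.data - buf.ctypes.data
print(delta, kvc_offset_bytes(5, 64))  # 640 640
```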
Results
By implementing IO-Bindings and pre-allocating memory for KVC tensors, we observed significant performance improvements compared to the naive implementation:
Model: Esperanto/llama3-8b-Instruct-kvc-AWQ-int4-onnx (Hugging Face)
Number of tokens requested: 120 (num_tokens, which counts the prompt tokens as well as the generated ones)
Test question: “when summer starts?”
Naive Implementation Results
- Inference Time: 19.19 seconds (5.21 tokens/s)
IO-Bindings Implementation Results
- Inference Time: 8.68 seconds (11.51 tokens/s)
We can also compare the execution traces of the naive and IO-Bindings implementations.

Figure 2: Naive vs IO-Bindings Implementation
As shown in Figures 2 and 3, the IO-Bindings implementation completely eliminated the extra overhead of handling the past and present KVC tensors. In Figure 3, we can see a significant reduction in DKernel back-to-back latency after applying IO-Bindings.

Figure 3: DKernel Back-to-Back Latency
Furthermore, token generation latency, measured as the time between the completion of two consecutive DKernels (Figure 4), improves from approximately 185 ms to 76 ms.

Figure 4: Token Generation Latency From Device
Conclusion
IO-Bindings unlock the full potential of ONNXRuntime by minimizing data transfer overhead and optimizing memory usage. When running small language models with KVC on accelerators, these techniques are essential for achieving high-performance inference. The results speak for themselves: inference speed improves from 5.21 tokens/s to 11.51 tokens/s, a 2.2x improvement in throughput, and token generation latency drops from approximately 185 ms to 76 ms, roughly a 2.4x improvement.
