In a previous blog post, we demonstrated a streamlined implementation for running small language models (SLMs) with Key-Value Cache (KVC) on the ET-SoC-1 accelerator using ONNXRuntime. That implementation required only about 50 lines of code, showcasing how simple and efficient inference can be with the right tools and optimizations.

While the code is concise, a closer examination reveals room for optimization. Specifically, we identified inefficiencies in how the inputs for the next token generation are prepared, which lead to excessive host-device communication. This post dives deeper into these inefficiencies and presents a better approach using the ONNXRuntime IO-Bindings feature.

Revisiting the Initial Implementation

import numpy as np
import onnxruntime
from scipy.special import softmax  # assumed source of softmax; any 1-D softmax works here
from transformers import AutoTokenizer

# update_next_index() is a helper that is not shown in this post.


def llm_kvc_inference(session: onnxruntime.InferenceSession, tokenizer: AutoTokenizer, input_tensors: dict,
                      prompt_tensor, num_tokens, context, sequence_len, window: int, batch: int) -> tuple[list[str], float]:
    sum_perplexity = 0
    new_token = np.array([10])  # placeholder token id so the first EOS check does not trigger
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window

    inputs_names = [input.name for input in session.get_inputs()]
    output_names = [output.name for output in session.get_outputs()]

    # Run the inferences
    while next_index < num_tokens:
        if (new_token == tokenizer.eos_token_id).any():
            break

        # Inputs are host (NumPy) arrays, so every call copies them to the device and
        # copies every output, including all "present" tensors, back to the host.
        output = session.run(output_names, input_tensors)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}

        logits = outs_dictionary['logits']

        # Prepare next inference inputs (on the host)
        for name in inputs_names:
            if name == 'input_ids':
                current_index = next_index
                next_index = update_next_index(next_index, prompt_size, window)
                j = next_index - window

                if current_index >= prompt_size:
                    top1 = logits.argmax(-1)
                    new_token = top1.reshape(batch, window)  # Infer next token
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)

                input_tensors['input_ids'] = total_input[:, j:next_index].reshape(batch, window)
            elif name == 'attention_mask':
                attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                attention_mask[:, -next_index:] = 1
                input_tensors['attention_mask'] = attention_mask
            else:
                # past_key_values.* <- host-side slice of the matching present.* output
                old_name = name.replace("past_key_values", "present")
                start_idx = next_index - current_index
                end_idx = start_idx + context - window
                input_tensors[name] = outs_dictionary[old_name][:, :, start_idx:end_idx, :]

        sum_perplexity += -np.log(softmax(logits[0, 0])[input_tensors['input_ids'][0]])

    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity

Here’s a quick breakdown of the initial implementation:

  • Inference Execution: The session.run call runs the model with ONNXRuntime to generate the next token.

  • Next Token Input Preparation: The rest of the loop body prepares the inputs for the next token generation. This includes updating the KVC tensors (“past”) with the newly generated “present” values.

Since the input preparation is performed on the host, ONNXRuntime needs to:

  1. Transfer “present” tensors from the device to the host.

  2. Allow the host to prepare new “past” tensors.

  3. Transfer the updated “past” tensors back to the device.

This process adds latency because of repeated device-to-host-to-device memory transfers (one for every key and value tensor of every model layer), which significantly impacts performance.
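To get a feel for how much data this round trip moves, here is a rough back-of-the-envelope estimate. It assumes the KVC shapes used by the code in this post, (batch, nheads, context, hidden) in float16, with “past” one position shorter than “present” (window = 1). The layer count, key-value head count and head dimension are Llama 3 8B's published attention configuration; the context length of 2048 is a hypothetical value, not one taken from the exported model, and kvc_transfer_bytes is an illustrative helper.

```python
def kvc_transfer_bytes(batch: int, layers: int, kv_heads: int, context: int,
                       head_dim: int, itemsize: int = 2) -> int:
    """Rough bytes moved per generation step when KVC inputs are prepared on the host.

    Device -> host: every "present" key and value tensor, (batch, kv_heads, context, head_dim).
    Host -> device: every "past" key and value tensor, (batch, kv_heads, context - 1, head_dim).
    """
    bytes_per_position = batch * kv_heads * head_dim * itemsize
    present = 2 * layers * context * bytes_per_position  # key + value, every layer
    past = 2 * layers * (context - 1) * bytes_per_position
    return present + past


# Llama 3 8B attention configuration; context=2048 is a hypothetical export length.
total = kvc_transfer_bytes(batch=1, layers=32, kv_heads=8, context=2048, head_dim=128)
print(f"{total:,} bytes (~{total / 2**20:.0f} MiB) per step")
```

Under these assumptions each step moves roughly 512 MiB across the host-device link, before any useful compute happens.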

Figure 1: Extra Overhead Added by Performing Next-Inference Preparations on the Host

Figure 1 shows that, on average, each inference can spend approximately 103 ms preparing the inputs for the next token.

A Better Way with IO-Bindings

To overcome these inefficiencies, we can leverage the IO-Bindings feature available in upstream ONNXRuntime. It lets developers bind input and output tensors directly to device memory, eliminating unnecessary data transfers between the host and the device.

Here is the key insight: the “present” KVC tensors produced by one inference step can be reused as the “past” inputs of the next. By pre-allocating memory shared by the “past” and “present” tensors, with sufficient padding, we can implement a sliding window mechanism: between steps we only update the device pointers the tensors are bound to, with no additional memory transfers.
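The sliding-window idea can be illustrated on the host with NumPy views. This is a simplified, single-head sketch (simulate_sliding_kvc is a hypothetical helper, not part of ONNXRuntime): one buffer holds the cache, “past” and “present” are views that start at the same row, and moving the start row forward by one between steps drops the oldest entry without copying anything.

```python
import numpy as np


def simulate_sliding_kvc(context: int, steps: int, hidden: int = 4) -> np.ndarray:
    """Single-head sketch: past/present share one allocation; only the view start moves."""
    # One allocation with enough slack for `steps` one-row shifts.
    buf = np.zeros((context + steps - 1, hidden), dtype=np.float16)
    for t in range(steps):
        past = buf[t : t + context - 1]  # what the model reads (context - 1 rows)
        present = buf[t : t + context]   # what the model writes (context rows)
        assert past.base is buf and present.base is buf  # views, not copies
        # "present" is "past" followed by the new token's entry, and the first
        # context - 1 rows already hold "past", so only the last row changes.
        present[-1] = t + 1  # stand-in for the new token's key/value
    return buf
```

After step t, the next “past” view, buf[t + 1 : t + context], is the previous “present” without its oldest row, which is exactly the sliding window.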

Implementing IO-Bindings

Below are simplified code snippets demonstrating this optimization. First, we introduce two new helper functions required by the new llm_kvc_inference_iobindings implementation.

To implement a sliding window, each past/present tensor pair must be bound to the same underlying memory allocation. preallocate_input_outputs achieves this by creating, for every “present” tensor, an OrtValue large enough to back both tensors, plus padding.

import numpy as np
import onnxruntime
from onnxruntime.transformers.io_binding_helper import TypeHelper


def preallocate_input_outputs(session: onnxruntime.InferenceSession, output_names: list[str], input_tensors: dict,
                              window: int, batch: int, nheads: int, context: int, hidden: int, logits_last_dim: int,
                              device_type="et", device_id=0) -> dict:
    # Copy the initial input_ids/attention_mask to the device once and allocate
    # a device buffer for the logits output.
    logits_dtype = TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'logits'))
    ortvalues = {
        'input_ids':      onnxruntime.OrtValue.ortvalue_from_numpy(input_tensors['input_ids'], device_type, device_id),
        'attention_mask': onnxruntime.OrtValue.ortvalue_from_numpy(input_tensors['attention_mask'], device_type, device_id),
        'logits':         onnxruntime.OrtValue.ortvalue_from_shape_and_type((batch, window, logits_last_dim), logits_dtype,
                                                                            device_type, device_id),
    }
    # All KVC input (past_key_values) and output (present) tensors share the same underlying allocation.
    kvc_dtype = TypeHelper.ort_type_to_numpy_type(TypeHelper.get_output_type(session, 'present.0.key'))
    zeros = np.zeros((batch, nheads, context, hidden), dtype=kvc_dtype)
    # Pad the heads axis by one: the extra space lets update_kvc_view advance the
    # base pointer without the bound views running past the end of the allocation.
    zeros_padded = np.pad(zeros, ((0, 0), (0, 1), (0, 0), (0, 0)), mode='constant')
    for name in output_names:
        if 'present' in name:
            ortvalues[name] = onnxruntime.OrtValue.ortvalue_from_numpy(zeros_padded, device_type, device_id)
    return ortvalues
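How much headroom does that single extra head give? The sketch below (padded_kvc_capacity is a hypothetical helper) applies the same np.pad call on the host and computes how many sequence positions the base pointer can advance, in steps of hidden elements as update_kvc_view does, before a “present” binding of shape (batch, nheads, context, hidden) would extend past the allocation.

```python
import numpy as np


def padded_kvc_capacity(batch: int, nheads: int, context: int, hidden: int,
                        dtype=np.float16) -> tuple[tuple[int, ...], int, int]:
    zeros = np.zeros((batch, nheads, context, hidden), dtype=dtype)
    # Same padding as preallocate_input_outputs: one extra head per batch entry.
    padded = np.pad(zeros, ((0, 0), (0, 1), (0, 0), (0, 0)), mode="constant")
    present_elems = batch * nheads * context * hidden  # extent of one "present" binding
    # Positions (rows of `hidden` elements) the pointer can advance while the
    # shifted "present" view still fits inside the padded buffer.
    slack_rows = (padded.size - present_elems) // hidden
    return padded.shape, padded.nbytes, slack_rows
```

For batch 1 the slack is exactly context positions, so the padding bounds how far the view can slide before the buffer would need to be re-initialized.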

Next, bind_input_outputs binds the past/present tensors by passing the appropriate device pointer (the buffer_ptr arguments):

import numpy as np


def bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues, batch: int, nheads: int, context: int, hidden: int):
    # Bind inputs to the pre-allocated device memory
    for name in inputs_names:
        if name in ['input_ids', 'attention_mask']:
            # For input_ids and attention_mask, bind the OrtValue directly
            io_binding.bind_ortvalue_input(name, ortvalues[name])
        else:
            # For past_key_values.*, bind a raw pointer into the matching present.* allocation,
            # viewed as one position shorter along the sequence axis
            out_name = name.replace("past_key_values", "present")
            io_binding.bind_input(name,
                                  device_type=ortvalues[out_name].device_name(),
                                  device_id=0,
                                  element_type=np.float16,  # KVC tensors are assumed to be float16
                                  shape=(batch, nheads, context - 1, hidden),
                                  buffer_ptr=ortvalues[out_name].data_ptr())
    # Bind outputs
    for name in outputs_names:
        if 'logits' in name:
            io_binding.bind_ortvalue_output(name, ortvalues[name])
        else:
            # present.* uses the same buffer and pointer as the matching past input
            io_binding.bind_output(name,
                                   device_type=ortvalues[name].device_name(),
                                   device_id=0,
                                   element_type=np.float16,
                                   shape=(batch, nheads, context, hidden),
                                   buffer_ptr=ortvalues[name].data_ptr())

Next is the optimized llm_kvc_inference_iobindings. Notice how all inputs and outputs are pre-allocated on the device by preallocate_input_outputs, then bound with bind_input_outputs, and how each inference is launched by session.run_with_iobinding instead of session.run. Also notice that preparing the next token's inputs uses OrtValue.update_inplace for input_ids and attention_mask, while a new function, update_kvc_view, handles the KVC tensors.

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import onnxruntime
from transformers import AutoTokenizer

# preallocate_input_outputs(), bind_input_outputs() and update_kvc_view() are defined in this post;
# update_next_index() and compute_perplexity() are helpers that are not shown.


def llm_kvc_inference_iobindings(session: onnxruntime.InferenceSession, run_options: onnxruntime.RunOptions,
                                 tokenizer: AutoTokenizer, input_tensors: dict,
                                 prompt_tensor, num_tokens, context, sequence_len, window: int, batch: int) -> tuple[list[str], float]:

    sum_perplexity = 0
    new_token = np.array([10])  # placeholder token id so the first EOS check does not trigger
    prompt_size = len(prompt_tensor[0])
    total_input = prompt_tensor
    current_index = 0
    next_index = window

    # Output 0 is assumed to be 'logits' and output 1 the first 'present' tensor
    logits_last_dim = int(session.get_outputs()[0].shape[-1])
    hidden = int(session.get_outputs()[1].shape[-1])
    nheads = int(session.get_outputs()[1].shape[-3])

    inputs_names = [input.name for input in session.get_inputs()]
    outputs_names = [output.name for output in session.get_outputs()]

    # Pre-allocate inputs and outputs on the device
    ortvalues = preallocate_input_outputs(session, outputs_names, input_tensors, window, batch, nheads, context, hidden, logits_last_dim)

    # Create IO-Bindings
    io_binding = session.io_binding()
    bind_input_outputs(io_binding, inputs_names, outputs_names, ortvalues, batch, nheads, context, hidden)

    with ThreadPoolExecutor() as executor:
        futures = []

        # Run the inferences
        while next_index < num_tokens:
            if (new_token == tokenizer.eos_token_id).any():
                break

            session.run_with_iobinding(io_binding, run_options)

            # Only the logits are copied back to the host
            logits = ortvalues['logits'].numpy()

            # Prepare next inference inputs
            for name in inputs_names:
                if name == 'input_ids':
                    current_index = next_index
                    next_index = update_next_index(next_index, prompt_size, window)
                    j = next_index - window

                    if current_index >= prompt_size:
                        top1 = logits.argmax(-1)
                        new_token = top1.reshape(batch, window)  # Infer next token
                        total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)

                    next_input_ids = total_input[:, j:next_index]
                    # Overwrite the bound device buffer in place; the binding stays valid
                    ortvalues['input_ids'].update_inplace(np.ascontiguousarray(next_input_ids))
                elif name == 'attention_mask':
                    attention_mask = np.zeros((batch, sequence_len), dtype='int64')
                    attention_mask[:, -next_index:] = 1
                    ortvalues['attention_mask'].update_inplace(attention_mask)
                else:
                    # Re-point past/present at a shifted offset of the shared buffer
                    update_kvc_view(io_binding, name, ortvalues, current_index, batch, nheads, context, hidden)

            # Offload perplexity calculation to a separate thread
            futures.append(executor.submit(compute_perplexity, logits[0, 0], next_input_ids[0]))

        # Accumulate perplexity results from the helper threads
        for future in as_completed(futures):
            sum_perplexity += future.result()

    sum_perplexity /= (num_tokens - 1)
    answers = tokenizer.batch_decode(total_input, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answers, sum_perplexity
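compute_perplexity is called above but not shown. A plausible version, written here only as a hypothetical sketch, mirrors the naive loop's expression -np.log(softmax(logits)[ids]) using a numerically stable log-softmax. The accumulate_perplexity helper (also hypothetical) reproduces the submit-then-sum pattern of the optimized loop.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np


def compute_perplexity(token_logits: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """Hypothetical helper: negative log-probability of token_ids under token_logits."""
    x = token_logits.astype(np.float64)
    shifted = x - x.max()  # subtract the max before exponentiating for stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[token_ids]


def accumulate_perplexity(logit_rows, id_rows) -> float:
    """Submit one job per step and sum the results as they complete."""
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(compute_perplexity, lg, ids) for lg, ids in zip(logit_rows, id_rows)]
        return float(sum(f.result().sum() for f in as_completed(futures)))
```

Because the per-step results are only summed, completion order does not matter, which is why as_completed is safe here.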

Finally, update_kvc_view shows how bind_input and bind_output are used to move the device pointer of the past/present tensors.

import numpy as np


def update_kvc_view(io_binding, name, ortvalues, current_index: int, batch: int, nheads: int, context: int, hidden: int):
    out_name = name.replace("past_key_values", "present")
    # Advance the view by current_index sequence positions; each position
    # holds `hidden` float16 elements.
    element_size_bytes = 2  # float16
    num_bytes_per_cache_line = hidden * element_size_bytes
    offset_bytes = current_index * num_bytes_per_cache_line
    buffer_ptr = ortvalues[out_name].data_ptr() + offset_bytes
    device_type = ortvalues[out_name].device_name()
    # Re-bind past and present to the same shifted pointer; no data is copied.
    io_binding.bind_input(name,
                          device_type=device_type,
                          device_id=0,
                          element_type=np.float16,
                          shape=(batch, nheads, context - 1, hidden),
                          buffer_ptr=buffer_ptr)
    io_binding.bind_output(out_name,
                           device_type=device_type,
                           device_id=0,
                           element_type=np.float16,
                           shape=(batch, nheads, context, hidden),
                           buffer_ptr=buffer_ptr)
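The byte offset in update_kvc_view is current_index × hidden × element size (2 bytes for float16). As a host-side sanity check, the sketch below derives the element size from the dtype and confirms that, for a contiguous NumPy array, slicing off current_index rows of hidden elements moves the data pointer by exactly that many bytes. kvc_view_offset and check_offset_against_numpy are illustrative helpers, not part of ONNXRuntime.

```python
import numpy as np


def kvc_view_offset(current_index: int, hidden: int, dtype=np.float16) -> int:
    """Byte offset for a view that starts current_index sequence positions later."""
    return current_index * hidden * np.dtype(dtype).itemsize


def check_offset_against_numpy(current_index: int = 3, hidden: int = 8) -> bool:
    buf = np.zeros((16, hidden), dtype=np.float16)
    view = buf[current_index:]  # NumPy's own pointer arithmetic for the same shift
    moved = view.ctypes.data - buf.ctypes.data
    return moved == kvc_view_offset(current_index, hidden)
```

Deriving the size from the dtype keeps the offset correct if the KVC tensors are ever exported in another precision.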

Results

By implementing IO-Bindings and pre-allocating memory for KVC tensors, we observed significant performance improvements compared to the naive implementation:

Model: Esperanto/llama3-8b-Instruct-kvc-AWQ-int4-onnx (Hugging Face)

Number of tokens requested: 120

Test question: “when summer starts?”

Naive Implementation Results

  • Inference Time: 19.19 seconds (5.21 tokens/s)

IO-Bindings Implementation Results

  • Inference Time: 8.68 seconds (11.51 tokens/s)
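As a quick consistency check on these numbers, the wall-clock ratio and the throughput ratio give the same speedup:

```latex
\frac{19.19\ \text{s}}{8.68\ \text{s}} \approx 2.21
\qquad\qquad
\frac{11.51\ \text{tokens/s}}{5.21\ \text{tokens/s}} \approx 2.21
```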

We can also compare the execution traces of the naive and IO-Bindings implementations.

Figure 2: Naive vs IO-Bindings Implementation

As shown in Figures 2 and 3, the IO-Bindings implementation completely eliminated the extra overhead of handling the past and present KVC tensors. In Figure 3, we can see a significant reduction in DKernel back-to-back latency after applying IO-Bindings.

Figure 3: DKernel Back-to-Back Latency

Furthermore, token generation latency, defined as the time between the completion of two consecutive DKernels (Figure 4), improves from approximately 185 ms to 76 ms.
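The token generation latency metric can be computed from a list of DKernel completion timestamps. This sketch assumes the timestamps (in milliseconds) have already been extracted from the trace; token_latencies_ms is a hypothetical helper, and the trace format itself is not shown in this post.

```python
import numpy as np


def token_latencies_ms(completion_times_ms) -> np.ndarray:
    """Gaps between consecutive DKernel completions, one per generated token."""
    times = np.sort(np.asarray(completion_times_ms, dtype=np.float64))
    return np.diff(times)
```

Averaging the returned gaps gives a single per-token latency figure comparable to the ones quoted above.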

Figure 4: Token Generation Latency From Device

Conclusion

IO-Bindings unlock the full potential of ONNXRuntime by minimizing data transfer overhead and optimizing memory usage. When running small language models with KVC on accelerators, these techniques are essential for achieving high-performance inference. The results speak for themselves: inference speed improved from 5.21 tokens/s to 11.51 tokens/s, a 2.2x improvement in throughput, along with a 2.34x improvement in latency.