diff --git a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/api/VimKeyGroupBase.kt b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/api/VimKeyGroupBase.kt
index 456fffa13..08d5ca754 100644
--- a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/api/VimKeyGroupBase.kt
+++ b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/api/VimKeyGroupBase.kt
@@ -50,7 +50,7 @@ abstract class VimKeyGroupBase : VimKeyGroup {
   }
 
   override fun getKeyMapping(mode: MappingMode): KeyMapping {
-    return keyMappings.getOrPut(mode) { KeyMapping() }
+    return keyMappings.getOrPut(mode) { KeyMapping(mode.name[0].lowercase() + "map") }
   }
 
   override fun resetKeyMappings() {
diff --git a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/command/MappingProcessor.kt b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/command/MappingProcessor.kt
index 0c9547093..e58954611 100644
--- a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/command/MappingProcessor.kt
+++ b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/command/MappingProcessor.kt
@@ -21,6 +21,7 @@ import com.maddyhome.idea.vim.impl.state.toMappingMode
 import com.maddyhome.idea.vim.key.KeyConsumer
 import com.maddyhome.idea.vim.key.KeyMappingLayer
 import com.maddyhome.idea.vim.key.MappingInfoLayer
+import com.maddyhome.idea.vim.key.isPrefix
 import com.maddyhome.idea.vim.state.KeyHandlerState
 import javax.swing.KeyStroke
 
@@ -94,7 +95,7 @@ object MappingProcessor: KeyConsumer {
     // unless a sequence is also a prefix for another mapping. We eagerly evaluate the shortest mapping, so even if a
     // mapping is a prefix, it will get evaluated when the next character is entered.
     // Note that currentlyUnhandledKeySequence is the same as the state after commandState.getMappingKeys().add(key). It
-    // would be nice to tidy ths up
+    // would be nice to tidy this up
     if (!mapping.isPrefix(processBuilder.state.mappingState.keys)) {
       log.debug("There are no mappings that start with the current sequence. Returning false.")
       return false
@@ -161,7 +162,7 @@ object MappingProcessor: KeyConsumer {
     log.trace("Processing complete mapping sequence...")
     // The current sequence isn't a prefix, check to see if it's a completed sequence.
     val mappingState = processBuilder.state.mappingState
-    val currentMappingInfo = mapping.getLayer(mappingState.keys)
+    val currentMappingInfo = mapping.getLayer(mappingState.keys.toList())
     var mappingInfo = currentMappingInfo
     if (mappingInfo == null) {
       log.trace("Haven't found any mapping info for the given sequence. Trying to apply mapping to a subsequence.")
diff --git a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMapping.kt b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMapping.kt
index 44864f3a1..bde80cad8 100644
--- a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMapping.kt
+++ b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMapping.kt
@@ -10,8 +10,7 @@ package com.maddyhome.idea.vim.key
 import com.maddyhome.idea.vim.api.injector
 import com.maddyhome.idea.vim.extension.ExtensionHandler
 import com.maddyhome.idea.vim.vimscript.model.expressions.Expression
-import java.util.function.Consumer
-import java.util.stream.Collectors
+import org.jetbrains.annotations.TestOnly
 import javax.swing.KeyStroke
 
 /**
@@ -20,36 +19,37 @@ import javax.swing.KeyStroke
  *
  * @author vlan
  */
-class KeyMapping : Iterable<List<KeyStroke?>?>, KeyMappingLayer {
-  /**
-   * Contains all key mapping for some mode.
-   */
-  private val myKeys: MutableMap<List<KeyStroke>, MappingInfo> = HashMap()
+class KeyMapping(name: String) : Iterable<List<KeyStroke>>, KeyMappingLayer {
+  private val keysTrie = KeyStrokeTrie<MappingInfo>(name)
 
-  /**
-   * Set the contains all possible prefixes for mappings.
-   * E.g. if there is mapping for "hello", this set will contain "h", "he", "hel", etc.
-   * Multiset is used to correctly remove the mappings.
-   */
-  private val myPrefixes: MutableMap<List<KeyStroke>, Int> = HashMap()
-  override fun iterator(): MutableIterator<List<KeyStroke>> {
-    return ArrayList(myKeys.keys).iterator()
+  override fun iterator(): Iterator<List<KeyStroke>> = ArrayList(keysTrie.getAll().keys).iterator()
+
+  operator fun get(keys: List<KeyStroke>): MappingInfo? {
+    keysTrie.getData(keys)?.let { return it }
+
+    getActionNameFromActionMapping(keys)?.let {
+      return ToActionMappingInfo(it, keys, false, MappingOwner.IdeaVim.System)
+    }
+
+    return null
   }
 
+  @Deprecated("Use get(List<KeyStroke>)")
   operator fun get(keys: Iterable<KeyStroke>): MappingInfo? {
-    // Having a parameter of Iterable allows for a nicer API, because we know when a given list is immutable.
-    // TODO: Should we change this to be a trie?
-    assert(keys is List<*>) { "keys must be of type List<KeyStroke>" }
-    val keyStrokes = keys as List<KeyStroke>
-    val mappingInfo = myKeys[keys]
-    if (mappingInfo != null) return mappingInfo
-    if (keyStrokes.size > 3) {
-      if (keyStrokes[0].keyCode == injector.parser.actionKeyStroke.keyCode && keyStrokes[1].keyChar == '(' && keyStrokes[keyStrokes.size - 1].keyChar == ')') {
-        val builder = StringBuilder()
-        for (i in 2 until keyStrokes.size - 1) {
-          builder.append(keyStrokes[i].keyChar)
+    if (keys is List<KeyStroke>) {
+      return get(keys)
+    }
+    return get(keys.toList())
+  }
+
+  private fun getActionNameFromActionMapping(keys: List<KeyStroke>): String? {
+    if (keys.size > 3
+      && keys[0].keyCode == injector.parser.actionKeyStroke.keyCode
+      && keys[1].keyChar == '(' && keys.last().keyChar == ')') {
+      return buildString {
+        for (i in 2 until keys.size - 1) {
+          append(keys[i].keyChar)
         }
-        return ToActionMappingInfo(builder.toString(), keyStrokes, false, MappingOwner.IdeaVim.System)
       }
     }
     return null
@@ -61,8 +61,7 @@ class KeyMapping : Iterable<List<KeyStroke?>?>, KeyMappingLayer {
     extensionHandler: ExtensionHandler,
     recursive: Boolean,
   ) {
-    myKeys[ArrayList(fromKeys)] = ToHandlerMappingInfo(extensionHandler, fromKeys, recursive, owner)
-    fillPrefixes(fromKeys)
+    add(fromKeys, ToHandlerMappingInfo(extensionHandler, fromKeys, recursive, owner))
   }
 
   fun put(
@@ -71,8 +70,7 @@ class KeyMapping : Iterable<List<KeyStroke?>?>, KeyMappingLayer {
     owner: MappingOwner,
     recursive: Boolean,
   ) {
-    myKeys[ArrayList(fromKeys)] = ToKeysMappingInfo(toKeys, fromKeys, recursive, owner)
-    fillPrefixes(fromKeys)
+    add(fromKeys, ToKeysMappingInfo(toKeys, fromKeys, recursive, owner))
   }
 
   fun put(
@@ -82,104 +80,57 @@ class KeyMapping : Iterable<List<KeyStroke?>?>, KeyMappingLayer {
     originalString: String,
     recursive: Boolean,
   ) {
-    myKeys[ArrayList(fromKeys)] =
-      ToExpressionMappingInfo(toExpression, fromKeys, recursive, owner, originalString)
-    fillPrefixes(fromKeys)
+    add(fromKeys, ToExpressionMappingInfo(toExpression, fromKeys, recursive, owner, originalString))
   }
 
-  private fun fillPrefixes(fromKeys: List<KeyStroke>) {
-    val prefix: MutableList<KeyStroke> = ArrayList()
-    val prefixLength = fromKeys.size - 1
-    for (i in 0 until prefixLength) {
-      prefix.add(fromKeys[i])
-      myPrefixes[ArrayList(prefix)] = (myPrefixes[ArrayList(prefix)] ?: 0) + 1
-    }
+  private fun add(keys: List<KeyStroke>, mappingInfo: MappingInfo) {
+    keysTrie.add(keys, mappingInfo)
   }
 
   fun delete(owner: MappingOwner) {
-    val toRemove = myKeys.entries.stream()
-      .filter { (_, value): Map.Entry<List<KeyStroke>, MappingInfo> -> value.owner == owner }
-      .collect(Collectors.toList())
-    toRemove.forEach(
-      Consumer { (key, value): Map.Entry<List<KeyStroke>, MappingInfo> ->
-        myKeys.remove(
-          key,
-          value,
-        )
-      },
-    )
-    toRemove.map { it.key }.forEach(this::removePrefixes)
-  }
-
-  fun delete(keys: List<KeyStroke>) {
-    myKeys.remove(keys) ?: return
-    removePrefixes(keys)
-  }
-
-  fun delete() {
-    myKeys.clear()
-    myPrefixes.clear()
-  }
-
-  private fun removePrefixes(keys: List<KeyStroke>) {
-    val prefix: MutableList<KeyStroke> = ArrayList()
-    val prefixLength = keys.size - 1
-    for (i in 0 until prefixLength) {
-      prefix.add(keys[i])
-      val existingCount = myPrefixes[prefix]
-      if (existingCount == 1 || existingCount == null) {
-        myPrefixes.remove(prefix)
-      } else {
-        myPrefixes[prefix] = existingCount - 1
-      }
+    getByOwner(owner).forEach { (keys, _) ->
+      keysTrie.remove(keys)
     }
   }
 
-  fun getByOwner(owner: MappingOwner): List<Pair<List<KeyStroke>, MappingInfo>> {
-    return myKeys.entries.stream()
-      .filter { (_, value): Map.Entry<List<KeyStroke>, MappingInfo> -> value.owner == owner }
-      .map { (key, value): Map.Entry<List<KeyStroke>, MappingInfo> ->
-        Pair(
-          key,
-          value,
-        )
-      }.collect(Collectors.toList())
+  fun delete(keys: List<KeyStroke>) {
+    keysTrie.remove(keys)
   }
 
-  override fun isPrefix(keys: Iterable<KeyStroke>): Boolean {
-    // Having a parameter of Iterable allows for a nicer API, because we know when a given list is immutable.
-    // Perhaps we should look at changing this to a trie or something?
-    assert(keys is List<*>) { "keys must be of type List<KeyStroke>" }
-    val keyList = keys as List<KeyStroke>
-    if (keyList.isEmpty()) return false
-    if (myPrefixes.contains(keys)) return true
-    val firstChar = keyList[0].keyCode
-    val lastChar = keyList[keyList.size - 1].keyChar
+  fun delete() {
+    keysTrie.clear()
+  }
+
+  fun getByOwner(owner: MappingOwner): List<Pair<List<KeyStroke>, MappingInfo>> =
+    buildList {
+      keysTrie.getAll().forEach { (keys, mappingInfo) ->
+        if (mappingInfo.owner == owner) {
+          add(Pair(keys, mappingInfo))
+        }
+      }
+    }
+
+  override fun isPrefix(keys: List<KeyStroke>): Boolean {
+    if (keys.isEmpty()) return false
+
+    if (keysTrie.isPrefix(keys)) return true
+
+    val firstChar = keys.first().keyCode
+    val lastChar = keys.last().keyChar
     return firstChar == injector.parser.actionKeyStroke.keyCode && lastChar != ')'
   }
 
-  fun hasmapto(toKeys: List<KeyStroke?>): Boolean {
-    return myKeys.values.stream()
-      .anyMatch { o: MappingInfo? -> o is ToKeysMappingInfo && o.toKeys == toKeys }
+  fun hasmapto(toKeys: List<KeyStroke>) = keysTrie.getAll().any { (_, mappingInfo) ->
+    mappingInfo is ToKeysMappingInfo && mappingInfo.toKeys == toKeys
   }
 
-  fun hasmapfrom(fromKeys: List<KeyStroke?>): Boolean {
-    return myKeys.values.stream()
-      .anyMatch { o: MappingInfo? -> o is ToKeysMappingInfo && o.fromKeys == fromKeys }
-  }
+  fun hasmapfrom(fromKeys: List<KeyStroke>) = keysTrie.getData(fromKeys) != null
 
-  fun getMapTo(toKeys: List<KeyStroke?>): List<Pair<List<KeyStroke>, MappingInfo>> {
-    return myKeys.entries.stream()
-      .filter { (_, value): Map.Entry<List<KeyStroke>, MappingInfo> -> value is ToKeysMappingInfo && value.toKeys == toKeys }
-      .map { (key, value): Map.Entry<List<KeyStroke>, MappingInfo> ->
-        Pair(
-          key,
-          value,
-        )
-      }.collect(Collectors.toList())
-  }
+  @TestOnly
+  fun getMapTo(toKeys: List<KeyStroke?>) =
+    keysTrie.getAll().filter { (_, mappingInfo) ->
+      mappingInfo is ToKeysMappingInfo && mappingInfo.toKeys == toKeys
+    }.map { it.toPair() }
 
-  override fun getLayer(keys: Iterable<KeyStroke>): MappingInfoLayer? {
-    return get(keys)
-  }
+  override fun getLayer(keys: List<KeyStroke>): MappingInfoLayer? = get(keys)
 }
diff --git a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMappingLayer.kt b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMappingLayer.kt
index e14e78c89..7d7c96b8f 100644
--- a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMappingLayer.kt
+++ b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyMappingLayer.kt
@@ -11,6 +11,13 @@ package com.maddyhome.idea.vim.key
 import javax.swing.KeyStroke
 
 interface KeyMappingLayer {
-  fun isPrefix(keys: Iterable<KeyStroke>): Boolean
-  fun getLayer(keys: Iterable<KeyStroke>): MappingInfoLayer?
+  fun isPrefix(keys: List<KeyStroke>): Boolean
+  fun getLayer(keys: List<KeyStroke>): MappingInfoLayer?
+}
+
+internal fun KeyMappingLayer.isPrefix(keys: Iterable<KeyStroke>): Boolean {
+  if (keys is List<KeyStroke>) {
+    return isPrefix(keys)
+  }
+  return isPrefix(keys.toList())
 }
diff --git a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyStrokeTrie.kt b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyStrokeTrie.kt
index ca4d35b5b..a61984761 100644
--- a/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyStrokeTrie.kt
+++ b/vim-engine/src/main/kotlin/com/maddyhome/idea/vim/key/KeyStrokeTrie.kt
@@ -27,7 +27,7 @@ class KeyStrokeTrie<T>(private val name: String) {
     val debugString: String
   }
 
-  private class TrieNodeImpl<T>(val name: String, val depth: Int, override val data: T?)
+  private class TrieNodeImpl<T>(val name: String, val depth: Int, override var data: T?)
     : TrieNode<T> {
 
     val children = lazy { mutableMapOf<KeyStroke, TrieNodeImpl<T>>() }
@@ -86,6 +86,9 @@ class KeyStrokeTrie<T>(private val name: String) {
         TrieNodeImpl(name, current.depth + 1, if (i == keyStrokes.lastIndex) data else null)
       }
     }
+
+    // Last write wins (also means we can't cache results)
+    current.data = data
   }
 
   /**
@@ -118,6 +121,41 @@ class KeyStrokeTrie<T>(private val name: String) {
     return current
   }
 
+  /**
+   * Returns true if the given keys are a prefix to a longer sequence of keys
+   *
+   * Will return true even if the current keys map to a node with data.
+   */
+  fun isPrefix(keyStrokes: List<KeyStroke>): Boolean {
+    val node = getTrieNode(keyStrokes) as? TrieNodeImpl<T> ?: return false
+    return node.children.isInitialized() && node.children.value.isNotEmpty()
+  }
+
+  fun remove(keys: List<KeyStroke>) {
+    val path = buildList {
+      var current = root
+      keys.forEach { key ->
+        if (!current.children.isInitialized()) return
+        val next = current.children.value[key] ?: return
+        add(Pair(current, key))
+        current = next
+      }
+    }
+
+    path.asReversed().forEach { (parent, key) ->
+      val child = parent.children.value[key] ?: return
+      if (child.children.isInitialized() && child.children.value.isNotEmpty()) return
+      parent.children.value.remove(key)
+      if (parent.children.value.isNotEmpty() || parent.data != null) return
+    }
+  }
+
+  fun clear() {
+    if (root.children.isInitialized()) {
+      root.children.value.clear()
+    }
+  }
+
   override fun toString(): String {
     val children = if (root.children.isInitialized()) {
       "${root.children.value.size} children"