Michael.W基于Foundry精读Openzeppelin第19期——EnumerableSet.sol

  • Michael.W
  • 更新于 2023-08-05 23:55
  • 阅读 1269

EnumerableSet库提供了Bytes32Set、AddressSet和UintSet三种类型的set,分别适用于bytes32、address和uint256类型的元素。 每种set都提供了对应的增添元素、删除元素、查询当前set中元素个数等操作。几乎所有操作的时间复杂度均为O(1)。

0. 版本

[openzeppelin]:v4.8.3,[forge-std]:v1.5.6

0.1 EnumerableSet.sol

Github: https://github.com/OpenZeppelin/openzeppelin-contracts/blob/v4.8.3/contracts/utils/structs/EnumerableSet.sol

EnumerableSet库提供了Bytes32Set、AddressSet和UintSet三种类型的set,分别用于bytes32、address和uint256类型的元素。 每种set都提供了对应的增添元素、删除元素、检查目标元素是否处于set中、查询当前set中元素个数等操作。几乎所有操作的时间复杂度均为O(1)。

1. 目标合约

封装EnumerableSet library成为一个可调用合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/src/utils/structs/MockEnumerableSet.sol

// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "openzeppelin-contracts/contracts/utils/structs/EnumerableSet.sol";

contract MockBytes32Set {
    using EnumerableSet for EnumerableSet.Bytes32Set;

    EnumerableSet.Bytes32Set _bytes32Set;

    function add(bytes32 value) external returns (bool) {
        return _bytes32Set.add(value);
    }

    function remove(bytes32 value) external returns (bool){
        return _bytes32Set.remove(value);
    }

    function contains(bytes32 value) external view returns (bool) {
        return _bytes32Set.contains(value);
    }

    function length() external view returns (uint) {
        return _bytes32Set.length();
    }

    function at(uint index) external view returns (bytes32){
        return _bytes32Set.at(index);
    }

    function values() external view returns (bytes32[] memory){
        return _bytes32Set.values();
    }
}

contract MockAddressSet {
    using EnumerableSet for EnumerableSet.AddressSet;

    EnumerableSet.AddressSet _addressSet;

    function add(address value) external returns (bool) {
        return _addressSet.add(value);
    }

    function remove(address value) external returns (bool){
        return _addressSet.remove(value);
    }

    function contains(address value) external view returns (bool) {
        return _addressSet.contains(value);
    }

    function length() external view returns (uint) {
        return _addressSet.length();
    }

    function at(uint index) external view returns (address){
        return _addressSet.at(index);
    }

    function values() external view returns (address[] memory){
        return _addressSet.values();
    }
}

contract MockUintSet {
    using EnumerableSet for EnumerableSet.UintSet;

    EnumerableSet.UintSet _uintSet;

    function add(uint value) external returns (bool) {
        return _uintSet.add(value);
    }

    function remove(uint value) external returns (bool){
        return _uintSet.remove(value);
    }

    function contains(uint value) external view returns (bool) {
        return _uintSet.contains(value);
    }

    function length() external view returns (uint) {
        return _uintSet.length();
    }

    function at(uint index) external view returns (uint){
        return _uintSet.at(index);
    }

    function values() external view returns (uint[] memory){
        return _uintSet.values();
    }
}

全部foundry测试合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/test/utils/structs/EnumerableSet.t.sol

2. 代码精读

2.1 结构体Set

结构体Set是由一个存储set中元素值的bytes32数组和一个用于记录元素值在元素数组中的index的mapping构成:

    struct Set {
        // 用于存放set内元素值的数组。存储类型为bytes32,根据需求可以对此进行适当修改
        bytes32[] _values;

        // 用于记录元素值在_values数组中的index的mapping。如果一个元素值的index记录值为0,表示该元素不存在于set中
        mapping(bytes32 => uint256) _indexes;
    }

结构体Set和它对应的方法都不对外开放。

2.1.1 _contains(Set storage set, bytes32 value) && _length(Set storage set) && _at(Set storage set, uint256 index) &&
  • _contains(Set storage set, bytes32 value):查看元素value是否存在于set中。如果存在,返回true。时间复杂度为O(1);
  • _length(Set storage set):返回当前set中的元素个数。时间复杂度为O(1);
  • _at(Set storage set, uint256 index):返回当前set中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < set中的元素总个数;
  • _values(Set storage set):返回当前set中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。
    function _contains(Set storage set, bytes32 value) private view returns (bool) {
        // 如果记录的value元素对应index为0表示不存在,不为0表示存在
        return set._indexes[value] != 0;
    }

    function _length(Set storage set) private view returns (uint256) {
        // 返回Set._values的长度
        return set._values.length;
    }

    function _at(Set storage set, uint256 index) private view returns (bytes32) {
        // 直接从Set._values中用index取值
        return set._values[index];
    }

    function _values(Set storage set) private view returns (bytes32[] memory) {
        // 直接将整个bytes32[] storage复制到memory中返回
        return set._values;
    }
2.1.2 _add(Set storage set, bytes32 value)

向set中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1)

    function _add(Set storage set, bytes32 value) private returns (bool) {
        if (!_contains(set, value)) {
            // 如果元素value不存在于当前set中,向Set._values中添加该元素
            set._values.push(value);
            // 在Set._indexes中记录该元素value位于Set._values数组中的index——即当前Set._values数组的长度。
            // 注:按照传统编程思想该元素位于Set._values数组最后,其index应该为总长度-1。这里对所有的元素的index记录值都+1,其目的是为了将index 0作为非set元素的flag。如果不这么设计,第一个添加的元素的index就是0,这将导致查询该元素是否处于set中的结果不符合预期
            set._indexes[value] = set._values.length;
            // 增添了新元素返回true
            return true;
        } else {
            // 如果元素value已存在于set中,不进行任何添加操作并返回false
            return false;
        }
    }
2.1.3 _remove(Set storage set, bytes32 value)

从set中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)

    function _remove(Set storage set, bytes32 value) private returns (bool) {
        // 获取元素value位于set中的index
        uint256 valueIndex = set._indexes[value];

        if (valueIndex != 0) {
            // 如果valueIndex不为0,表示该元素处于当前set中
            // 从一个数组删除某给位置的元素的思路:将数组最后一个元素复制到待删除元素的位置上,然后将最后一个元素pop。该操作的时间复杂度为O(1)
            // valueIndex-1为待删除元素在Set._values数组中的真实index
            uint256 toDeleteIndex = valueIndex - 1;
            // lastIndex为当前数组最后一个元素的真实index
            uint256 lastIndex = set._values.length - 1;

            if (lastIndex != toDeleteIndex) {
                // 如果待删除元素非数组内最后一个元素,取出数组最后一个元素的值
                bytes32 lastValue = set._values[lastIndex];
                // 数组待删除元素位置上的值替换为当前数组最后一个元素。其实此时已经实现了目标元素真正意义上的删除
                set._values[toDeleteIndex] = lastValue;
                // 由于数组最后一个元素已经换了位置,更新Set._indexes中最后一个元素的index为valueIndex,即待删除元素在Set._indexes中记录的index值
                set._indexes[lastValue] = valueIndex; 
            }

            // 直接pop掉Set._values中的最后一个元素
            set._values.pop();
            // 删除Set._indexes中关于待删除元素的记录
            delete set._indexes[value];
            // 返回true
            return true;
        } else {
            // 如果待删除元素value非set中的元素,直接返回false
            return false;
        }
    }

2.2 Bytes32Set体系

如果要存储的元素类型为bytes32,可以采用该体系

    struct Bytes32Set {
        // 封装了一个Set
        Set _inner;
    }
2.2.1 add(Bytes32Set storage set, bytes32 value) && remove(Bytes32Set storage set, bytes32 value)
  • add(Bytes32Set storage set, bytes32 value):向Bytes32Set中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);
  • remove(Bytes32Set storage set, bytes32 value):从Bytes32Set中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。
    function add(Bytes32Set storage set, bytes32 value) internal returns (bool) {
        // 直接调用Set._add()方法
        return _add(set._inner, value);
    }

    function remove(Bytes32Set storage set, bytes32 value) internal returns (bool) {
        // 直接调用Set._remove()方法
        return _remove(set._inner, value);
    }
2.2.2 contains(Bytes32Set storage set, bytes32 value) && length(Bytes32Set storage set) && at(Bytes32Set storage set, uint256 index) && values(Bytes32Set storage set)
  • contains(Bytes32Set storage set, bytes32 value):查看元素value是否存在于Bytes32Set中。如果存在,返回true。时间复杂度为O(1);
  • length(Bytes32Set storage set):返回当前Bytes32Set中的元素个数。时间复杂度为O(1);
  • at(Bytes32Set storage set, uint256 index):返回当前Bytes32Set中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < Bytes32Set中的元素总个数;
  • values(Bytes32Set storage set):返回当前Bytes32Set中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。
    function contains(Bytes32Set storage set, bytes32 value) internal view returns (bool) {
        // 直接调用Set._contains()方法
        return _contains(set._inner, value);
    }

    function length(Bytes32Set storage set) internal view returns (uint256) {
        // 直接调用Set._length()方法
        return _length(set._inner);
    }

    function at(Bytes32Set storage set, uint256 index) internal view returns (bytes32) {
        // 直接调用Set._at()方法
        return _at(set._inner, index);
    }

    function values(Bytes32Set storage set) internal view returns (bytes32[] memory) {
        // 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
        bytes32[] memory store = _values(set._inner);
        // 将store转换成Bytes32Set的外层封装类型bytes32[]
        // 注:个人认为对于Bytes32Set的values()方法可以直接返回store,不需要再做类型转换。由于本库的代码是由js代码生成的,所以此处没与后面的uint[]和address[]做差别处理
        bytes32[] memory result;
        /// @solidity memory-safe-assembly
        assembly {
            // 内联汇编中,直接在memory中进行bytes32[]->bytes32[]的类型转换
            result := store
        }
    // 返回类型转换后的bytes32[]
        return result;
    }
2.2.3 foundry代码验证
contract EnumerableSetTest is Test {
    MockBytes32Set mbs = new MockBytes32Set();

    function test_Bytes32Set_Operations() external {
        // empty
        assertEq(mbs.length(), 0);
        assertEq(mbs.values().length, 0);
        assertFalse(mbs.contains('a'));

        // add
        assertTrue(mbs.add('a'));
        assertTrue(mbs.contains('a'));
        assertEq(mbs.length(), 1);
        assertTrue(mbs.add('b'));
        assertEq(mbs.length(), 2);
        // add 'a' again
        assertFalse(mbs.add('a'));
        assertEq(mbs.length(), 2);
        bytes32[] memory values = mbs.values();
        assertEq('a', values[0]);
        assertEq('b', values[1]);

        assertTrue(mbs.add('c'));
        assertTrue(mbs.add('d'));
        assertEq(mbs.length(), 4);

        // remove
        // inner array: ['a','b','c','d']
        assertTrue(mbs.contains('b'));
        assertTrue(mbs.remove('b'));
        assertFalse(mbs.contains('b'));
        assertEq(mbs.length(), 3);
        // remove 'b' again
        assertFalse(mbs.remove('b'));
        assertEq(mbs.length(), 3);
        // inner array after remove: ['a','d','c']
        assertEq(mbs.at(0), 'a');
        assertEq(mbs.at(1), 'd');
        assertEq(mbs.at(2), 'c');
        // check values()
        values = mbs.values();
        assertEq('a', values[0]);
        assertEq('d', values[1]);
        assertEq('c', values[2]);

        // revert if out of bounds
        vm.expectRevert();
        mbs.at(1024);
    }
}

2.3 AddressSet体系

如果要存储的元素类型为address,可以采用该体系

    struct AddressSet {
        // 封装了一个Set
        Set _inner;
    }
2.3.1 add(AddressSet storage set, address value) && remove(AddressSet storage set, address value)
  • add(AddressSet storage set, address value):向AddressSet中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);
  • remove(AddressSet storage set, address value): 从AddressSet中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。
    function add(AddressSet storage set, address value) internal returns (bool) {
        // 直接调用Set._add()方法,参数value做了address->bytes32的类型转换
        return _add(set._inner, bytes32(uint256(uint160(value))));
    }

    function remove(AddressSet storage set, address value) internal returns (bool) {
        // 直接调用Set._remove()方法,参数value做了address->bytes32的类型转换
        return _remove(set._inner, bytes32(uint256(uint160(value))));
    }
2.3.2 contains(AddressSet storage set, address value) && length(AddressSet storage set) && at(AddressSet storage set, uint256 index) && values(AddressSet storage set)
  • contains(AddressSet storage set, address value):查看元素value是否存在于AddressSet中。如果存在,返回true。时间复杂度为O(1);
  • length(AddressSet storage set):返回当前AddressSet中的元素个数。时间复杂度为O(1);
  • at(AddressSet storage set, uint256 index):返回当前AddressSet中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < AddressSet中的元素总个数;
  • values(AddressSet storage set):返回当前AddressSet中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。
    function contains(AddressSet storage set, address value) internal view returns (bool) {
        // 直接调用Set._contains()方法,参数value做了address->bytes32的类型转换
        return _contains(set._inner, bytes32(uint256(uint160(value))));
    }

    function length(AddressSet storage set) internal view returns (uint256) {
        // 直接调用Set._length()方法
        return _length(set._inner);
    }

    function at(AddressSet storage set, uint256 index) internal view returns (address) {
        // 直接调用Set._at()方法,并将bytes32类型的返回值转换为address类型返回
        return address(uint160(uint256(_at(set._inner, index))));
    }

    function values(AddressSet storage set) internal view returns (address[] memory) {
        // 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
        bytes32[] memory store = _values(set._inner);
        // 将store转换成AddressSet的外层封装类型address[]
        address[] memory result;

        /// @solidity memory-safe-assembly
        assembly {
            // 内联汇编中,直接在memory中进行bytes32[]->address[]的类型转换
            result := store
        }
    // 返回类型转换后的address[]
        return result;
    }
2.3.3 foundry代码验证
contract EnumerableSetTest is Test {
    MockAddressSet mas = new MockAddressSet();

    function test_AddressSet_Operations() external {
        // empty
        assertEq(mas.length(), 0);
        assertEq(mas.values().length, 0);
        assertFalse(mas.contains(address(1)));

        // add
        assertTrue(mas.add(address(1)));
        assertTrue(mas.contains(address(1)));
        assertEq(mas.length(), 1);
        assertTrue(mas.add(address(2)));
        assertEq(mas.length(), 2);
        // add address(1) again
        assertFalse(mas.add(address(1)));
        assertEq(mas.length(), 2);
        address[] memory values = mas.values();
        assertEq(address(1), values[0]);
        assertEq(address(2), values[1]);

        assertTrue(mas.add(address(4)));
        assertTrue(mas.add(address(8)));
        assertEq(mas.length(), 4);

        // remove
        // inner array: [address(1),address(2),address(4),address(8)]
        assertTrue(mas.contains(address(2)));
        assertTrue(mas.remove(address(2)));
        assertFalse(mas.contains(address(2)));
        assertEq(mas.length(), 3);
        // remove address(2) again
        assertFalse(mas.remove(address(2)));
        assertEq(mas.length(), 3);
        // inner array after remove: [address(1),address(8),address(4)]
        assertEq(mas.at(0), address(1));
        assertEq(mas.at(1), address(8));
        assertEq(mas.at(2), address(4));
        // check values()
        values = mas.values();
        assertEq(address(1), values[0]);
        assertEq(address(8), values[1]);
        assertEq(address(4), values[2]);

        // revert if out of bounds
        vm.expectRevert();
        mas.at(1024);
    }
}

2.4 UintSet体系

如果要存储的元素类型为uint256,可以采用该体系

    struct UintSet {
        // 封装了一个Set
        Set _inner;
    }
2.4.1 add(UintSet storage set, uint256 value) && remove(UintSet storage set, uint256 value)
  • add(UintSet storage set, uint256 value):向UintSet中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);
  • remove(UintSet storage set, uint256 value):从UintSet中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。
    function add(UintSet storage set, uint256 value) internal returns (bool) {
        // 直接调用Set._add()方法,参数value做了uint256->bytes32的类型转换
        return _add(set._inner, bytes32(value));
    }

    function remove(UintSet storage set, uint256 value) internal returns (bool) {
        // 直接调用Set._remove()方法,参数value做了uint256->bytes32的类型转换
        return _remove(set._inner, bytes32(value));
    }
2.4.2 contains(UintSet storage set, uint256 value) && length(UintSet storage set) && at(UintSet storage set, uint256 index) && values(UintSet storage set)
  • contains(UintSet storage set, uint256 value):查看元素value是否存在于UintSet中。如果存在,返回true。时间复杂度为O(1);
  • length(UintSet storage set):返回当前UintSet中的元素个数。时间复杂度为O(1);
  • at(UintSet storage set, uint256 index):返回当前UintSet中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < UintSet中的元素总个数;
  • values(UintSet storage set):返回当前UintSet中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。
    function contains(UintSet storage set, uint256 value) internal view returns (bool) {
        // 直接调用Set._contains()方法,参数value做了uint256->bytes32的类型转换
        return _contains(set._inner, bytes32(value));
    }

    function length(UintSet storage set) internal view returns (uint256) {
        // 直接调用Set._length()方法
        return _length(set._inner);
    }

    function at(UintSet storage set, uint256 index) internal view returns (uint256) {
        // 直接调用Set._at()方法,并将bytes32类型的返回值转换为uint256类型返回
        return uint256(_at(set._inner, index));
    }

    function values(UintSet storage set) internal view returns (uint256[] memory) {
        // 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
        bytes32[] memory store = _values(set._inner);
        // 将store转换成UintSet的外层封装类型uint256[]
        uint256[] memory result;

        /// @solidity memory-safe-assembly
        assembly {
            // 内联汇编中,直接在memory中进行bytes32[]->uint256[]的类型转换
            result := store
        }
    // 返回类型转换后的uint256[]
        return result;
    }
2.4.3 foundry代码验证
contract EnumerableSetTest is Test {
    MockUintSet mus = new MockUintSet();

    function test_UintSet_Operations() external {
        // empty
        assertEq(mus.length(), 0);
        assertEq(mus.values().length, 0);
        assertFalse(mus.contains(1));

        // add
        assertTrue(mus.add(1));
        assertTrue(mus.contains(1));
        assertEq(mus.length(), 1);
        assertTrue(mus.add(2));
        assertEq(mus.length(), 2);
        // add 1 again
        assertFalse(mus.add(1));
        assertEq(mus.length(), 2);
        uint[] memory values = mus.values();
        assertEq(1, values[0]);
        assertEq(2, values[1]);

        assertTrue(mus.add(4));
        assertTrue(mus.add(8));
        assertEq(mus.length(), 4);

        // remove
        // inner array: [1,2,4,8]
        assertTrue(mus.contains(2));
        assertTrue(mus.remove(2));
        assertFalse(mus.contains(2));
        assertEq(mus.length(), 3);
        // remove 2 again
        assertFalse(mus.remove(2));
        assertEq(mus.length(), 3);
        // inner array after remove: [1,8,4]
        assertEq(mus.at(0), 1);
        assertEq(mus.at(1), 8);
        assertEq(mus.at(2), 4);
        // check values()
        values = mus.values();
        assertEq(1, values[0]);
        assertEq(8, values[1]);
        assertEq(4, values[2]);

        // revert if out of bounds
        vm.expectRevert();
        mus.at(1024);
    }
}

ps:\ 本人热爱图灵,热爱中本聪,热爱V神。 以下是我个人的公众号,如果有技术问题可以关注我的公众号来跟我交流。 同时我也会在这个公众号上每周更新我的原创文章,喜欢的小伙伴或者老伙计可以支持一下! 如果需要转发,麻烦注明作者。十分感谢!

1.jpeg

公众号名称:后现代泼痞浪漫主义奠基人

点赞 0
收藏 0
分享
本文参与登链社区写作激励计划 ,好文好收益,欢迎正在阅读的你也加入。

0 条评论

请先 登录 后评论
Michael.W
Michael.W
0x93E7...0000
狂热的区块链爱好者