OpenJDK / lambda / lambda / jdk
changeset 7808:f7345843a1d6
override default methods for CopyOnWriteArrayList's subLists
author | akhil |
---|---|
date | Wed, 03 Apr 2013 22:00:15 -0700 |
parents | 68138df9be76 |
children | 8c1a7e894916 |
files | src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java |
diffstat | 3 files changed, 316 insertions(+), 37 deletions(-) [+] |
line wrap: on
line diff
--- a/src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java Wed Apr 03 18:19:56 2013 -0400 +++ b/src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java Wed Apr 03 22:00:15 2013 -0700 @@ -1293,6 +1293,57 @@ } } + @Override + public void forEach(Consumer<? super E> action) { + @SuppressWarnings("unchecked") + final E[] elements = (E[]) l.getArray(); + checkForComodification(); + l.forEach(action, elements, offset, offset + size); + } + + @Override + public void sort(Comparator<? super E> c) { + final ReentrantLock lock = l.lock; + lock.lock(); + try { + checkForComodification(); + l.sort(c, offset, offset + size); + expectedArray = l.getArray(); + } finally { + lock.unlock(); + } + } + + @Override + public boolean removeIf(Predicate<? super E> filter) { + Objects.requireNonNull(filter); + final ReentrantLock lock = l.lock; + lock.lock(); + try { + checkForComodification(); + final int removeCount = + l.removeIf(filter, offset, offset + size); + expectedArray = l.getArray(); + size -= removeCount; + return removeCount > 0; + } finally { + lock.unlock(); + } + } + + @Override + public void replaceAll(UnaryOperator<E> operator) { + final ReentrantLock lock = l.lock; + lock.lock(); + try { + checkForComodification(); + l.replaceAll(operator, offset, offset + size); + expectedArray = l.getArray(); + } finally { + lock.unlock(); + } + } + public Spliterator<E> spliterator() { int lo = offset; int hi = offset + size; @@ -1377,13 +1428,18 @@ } } + @Override @SuppressWarnings("unchecked") - @Override public void forEach(Consumer<? super E> action) { + forEach(action, (E[]) getArray(), 0, size()); + } + + private void forEach(Consumer<? super E> action, + final E[] elements, + final int from, final int to) { Objects.requireNonNull(action); - final Object[] elements = getArray(); - for (final Object element : elements) { - action.accept((E) element); + for (int i = from; i < to; i++) { + action.accept(elements[i]); } } @@ -1393,52 +1449,91 @@ final ReentrantLock lock = this.lock; lock.lock(); try { - @SuppressWarnings("unchecked") - final E[] elements = (E[]) getArray(); - final E[] newElements = Arrays.copyOf(elements, elements.length); - Arrays.sort(newElements, c); - setArray(newElements); + sort(c, 0, size()); } finally { lock.unlock(); } } + // must be called with this.lock held + @SuppressWarnings("unchecked") + private void sort(Comparator<? super E> c, final int from, final int to) { + final E[] elements = (E[]) getArray(); + final E[] newElements = Arrays.copyOf(elements, elements.length); + for (int i = 0; i < from; i++) { + newElements[i] = elements[i]; + } + // only elements [from, to) are sorted + Arrays.sort(newElements, from, to, c); + for (int i = to; i < elements.length; i++) { + newElements[i] = elements[i]; + } + setArray(newElements); + } + @Override public boolean removeIf(Predicate<? super E> filter) { Objects.requireNonNull(filter); final ReentrantLock lock = this.lock; lock.lock(); try { + return removeIf(filter, 0, size()) > 0; + } finally { + lock.unlock(); + } + } + + // must be called with this.lock held + private int removeIf(Predicate<? super E> filter, final int from, final int to) { + Objects.requireNonNull(filter); + final ReentrantLock lock = this.lock; + lock.lock(); + try { @SuppressWarnings("unchecked") final E[] elements = (E[]) getArray(); - final int size = elements.length; // figure out which elements are to be removed // any exception thrown from the filter predicate at this stage // will leave the collection unmodified int removeCount = 0; - final BitSet removeSet = new BitSet(size); - for (int i=0; i < size; i++) { - final E element = elements[i]; + final int range = to - from; + final BitSet removeSet = new BitSet(range); + for (int i = 0; i < range; i++) { + final E element = elements[from + i]; if (filter.test(element)) { + // removeSet is zero-based to keep its size small removeSet.set(i); removeCount++; } } // copy surviving elements into a new array - final boolean anyToRemove = removeCount > 0; - if (anyToRemove) { + if (removeCount > 0) { final int newSize = elements.length - removeCount; - final Object[] newElements = new Object[newSize]; - for (int i=0, j=0; (i < size) && (j < newSize); i++, j++) { + final int newRange = newSize - from; + @SuppressWarnings("unchecked") + final E[] newElements = (E[]) new Object[newSize]; + // copy elements before [from, to) unmodified + for (int i = 0; i < from; i++) { + newElements[i] = elements[i]; + } + // elements [from, to) are subject to removal + int j = from; + for (int i = 0; (i < range) && (j < newRange); i++) { i = removeSet.nextClearBit(i); - newElements[j] = elements[i]; + if (i >= range) { + break; + } + newElements[j++] = elements[from + i]; + } + // copy any remaining elements beyond [from, to) + for (int i = to; (i < elements.length) && (j < newSize); i++) { + newElements[j++] = elements[i]; } setArray(newElements); } - return anyToRemove; + return removeCount; } finally { lock.unlock(); } @@ -1450,17 +1545,27 @@ final ReentrantLock lock = this.lock; lock.lock(); try { - @SuppressWarnings("unchecked") - final E[] elements = (E[]) getArray(); - final int len = elements.length; - @SuppressWarnings("unchecked") - final E[] newElements = (E[]) new Object[len]; - for (int i=0; i < len; i++) { - newElements[i] = operator.apply(elements[i]); - } - setArray(newElements); + replaceAll(operator, 0, size()); } finally { lock.unlock(); } } + + // must be called with this.lock held + @SuppressWarnings("unchecked") + private void replaceAll(UnaryOperator<E> operator, final int from, final int to) { + final E[] elements = (E[]) getArray(); + final E[] newElements = (E[]) new Object[elements.length]; + for (int i = 0; i < from; i++) { + newElements[i] = elements[i]; + } + // the operator is only applied to elements [from, to) + for (int i = from; i < to; i++) { + newElements[i] = operator.apply(elements[i]); + } + for (int i = to; i < elements.length; i++) { + newElements[i] = elements[i]; + } + setArray(newElements); + } }
--- a/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java Wed Apr 03 18:19:56 2013 -0400 +++ b/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java Wed Apr 03 22:00:15 2013 -0700 @@ -54,9 +54,11 @@ "java.util.TreeSet" }; + private static final int SIZE = 100; + @Test public void testForEach() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final Set<Integer> original = ((Set<Integer>) test.original); final Set<Integer> set = ((Set<Integer>) test.collection); @@ -74,7 +76,7 @@ @Test public void testRemoveIf() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final Set<Integer> original = ((Set<Integer>) test.original); final Set<Integer> set = ((Set<Integer>) test.collection);
--- a/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java Wed Apr 03 18:19:56 2013 -0400 +++ b/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java Wed Apr 03 22:00:15 2013 -0700 @@ -23,16 +23,19 @@ * questions. */ +import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.Comparators; import java.util.List; import java.util.LinkedList; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -70,9 +73,29 @@ private static final Comparator<AtomicInteger> ATOMIC_INTEGER_COMPARATOR = (x, y) -> x.intValue() - y.intValue(); + private static final int SIZE = 100; + private static final int SUBLIST_FROM = 2; + private static final int SUBLIST_TO = SIZE - SUBLIST_FROM; + private static final int SUBLIST_SIZE = SUBLIST_TO - SUBLIST_FROM; + + private static interface Callback { + void call(List<Integer> list); + } + + // call the callback for each recursive subList + private void trimmedSubList(final List<Integer> list, final Callback callback) { + int size = list.size(); + if (size > 1) { + // trim 1 element from both ends + final List<Integer> subList = list.subList(1, size - 1); + callback.call(subList); + trimmedSubList(subList, callback); + } + } + @Test public void testForEach() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> original = ((List<Integer>) test.original); final List<Integer> list = ((List<Integer>) test.collection); @@ -80,12 +103,41 @@ list.forEach(actual::add); CollectionAsserts.assertContents(actual, list); CollectionAsserts.assertContents(actual, original); + + if (original.size() > SUBLIST_SIZE) { + final List<Integer> subList = original.subList(SUBLIST_FROM, SUBLIST_TO); + final List<Integer> actualSubList = new LinkedList<>(); + subList.forEach(actualSubList::add); + assertEquals(actualSubList.size(), SUBLIST_SIZE); + for (int i = 0; i < SUBLIST_SIZE; i++) { + assertEquals(actualSubList.get(i), original.get(i + SUBLIST_FROM)); + } + } + + trimmedSubList(list, new Callback() { + @Override + public void call(final List<Integer> list) { + final List<Integer> actual = new LinkedList<>(); + list.forEach(actual::add); + CollectionAsserts.assertContents(actual, list); + } + }); } } @Test public void testRemoveIf() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE); + + for (final CollectionSupplier.TestCase test : supplier.get()) { + final List<Integer> original = ((List<Integer>) test.original); + final List<Integer> list = ((List<Integer>) test.collection); + final AtomicInteger offset = new AtomicInteger(1); + while (list.size() > 0) { + removeFirst(original, list, offset); + } + } + for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> original = ((List<Integer>) test.original); final List<Integer> list = ((List<Integer>) test.collection); @@ -101,12 +153,66 @@ list.removeIf(pEven); assertTrue(list.isEmpty()); } + + for (final CollectionSupplier.TestCase test : supplier.get()) { + final List<Integer> original = ((List<Integer>) test.original); + final List<Integer> list = ((List<Integer>) test.collection); + final List<Integer> listCopy = new ArrayList<>(list); + if (original.size() > SUBLIST_SIZE) { + final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO); + final List<Integer> subListCopy = new ArrayList<>(subList); + listCopy.removeAll(subList); + subList.removeIf(pOdd); + for (int i : subList) { + assertTrue((i % 2) == 0); + } + for (int i : subListCopy) { + if (i % 2 == 0) { + assertTrue(subList.contains(i)); + } else { + assertFalse(subList.contains(i)); + } + } + subList.removeIf(pEven); + assertTrue(subList.isEmpty()); + // elements outside the view should remain + CollectionAsserts.assertContents(list, listCopy); + } + } + + for (final CollectionSupplier.TestCase test : supplier.get()) { + final List<Integer> list = ((List<Integer>) test.collection); + trimmedSubList(list, new Callback() { + @Override + public void call(final List<Integer> list) { + final List<Integer> copy = new ArrayList<>(list); + list.removeIf(pOdd); + for (int i : list) { + assertTrue((i % 2) == 0); + } + for (int i : copy) { + if (i % 2 == 0) { + assertTrue(list.contains(i)); + } else { + assertFalse(list.contains(i)); + } + } + } + }); + } + } + + // remove the first element + private void removeFirst(final List<Integer> original, final List<Integer> list, final AtomicInteger offset) { + final AtomicBoolean first = new AtomicBoolean(true); + list.removeIf(x -> first.getAndSet(false)); + CollectionAsserts.assertContents(original.subList(offset.getAndIncrement(), original.size()), list); } @Test public void testReplaceAll() throws Exception { final int scale = 3; - final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> original = ((List<Integer>) test.original); final List<Integer> list = ((List<Integer>) test.collection); @@ -114,12 +220,47 @@ for (int i=0; i < original.size(); i++) { assertTrue(list.get(i) == (scale * original.get(i)), "mismatch at index " + i); } + + if (original.size() > SUBLIST_SIZE) { + final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO); + subList.replaceAll(x -> x + 1); + // verify elements in view [from, to) were replaced + for (int i = 0; i < SUBLIST_SIZE; i++) { + assertTrue(subList.get(i) == ((scale * original.get(i + SUBLIST_FROM)) + 1), + "mismatch at sublist index " + i); + } + // verify that elements [0, from) remain unmodified + for (int i = 0; i < SUBLIST_FROM; i++) { + assertTrue(list.get(i) == (scale * original.get(i)), + "mismatch at original index " + i); + } + // verify that elements [to, size) remain unmodified + for (int i = SUBLIST_TO; i < list.size(); i++) { + assertTrue(list.get(i) == (scale * original.get(i)), + "mismatch at original index " + i); + } + } + } + + for (final CollectionSupplier.TestCase test : supplier.get()) { + final List<Integer> list = ((List<Integer>) test.collection); + trimmedSubList(list, new Callback() { + @Override + public void call(final List<Integer> list) { + final List<Integer> copy = new ArrayList<>(list); + final int offset = 5; + list.replaceAll(x -> offset + x); + for (int i=0; i < copy.size(); i++) { + assertTrue(list.get(i) == (offset + copy.get(i)), "mismatch at index " + i); + } + } + }); } } @Test public void testSort() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> original = ((List<Integer>) test.original); final List<Integer> list = ((List<Integer>) test.collection); @@ -173,12 +314,43 @@ for (int i=0; i < test.original.size(); i++) { assertEquals(i, incomparables.get(i).intValue()); } + + if (original.size() > SUBLIST_SIZE) { + final List<Integer> copy = new ArrayList<>(list); + final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO); + CollectionSupplier.shuffle(subList); + subList.sort(Comparators.<Integer>naturalOrder()); + CollectionAsserts.assertSorted(subList, Comparators.<Integer>naturalOrder()); + // verify that elements [0, from) remain unmodified + for (int i = 0; i < SUBLIST_FROM; i++) { + assertTrue(list.get(i) == copy.get(i), + "mismatch at index " + i); + } + // verify that elements [to, size) remain unmodified + for (int i = SUBLIST_TO; i < list.size(); i++) { + assertTrue(list.get(i) == copy.get(i), + "mismatch at index " + i); + } + } + } + + for (final CollectionSupplier.TestCase test : supplier.get()) { + final List<Integer> list = ((List<Integer>) test.collection); + trimmedSubList(list, new Callback() { + @Override + public void call(final List<Integer> list) { + final List<Integer> copy = new ArrayList<>(list); + CollectionSupplier.shuffle(list); + list.sort(Comparators.<Integer>naturalOrder()); + CollectionAsserts.assertSorted(list, Comparators.<Integer>naturalOrder()); + } + }); } } @Test public void testRemoveIfThrowsCME() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> list = ((List<Integer>) test.collection); if (list.size() <= 1) { @@ -199,7 +371,7 @@ @Test public void testReplaceAllThrowsCME() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> list = ((List<Integer>) test.collection); if (list.size() <= 1) { @@ -220,7 +392,7 @@ @Test public void testSortThrowsCME() throws Exception { - final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100); + final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE); for (final CollectionSupplier.TestCase test : supplier.get()) { final List<Integer> list = ((List<Integer>) test.collection); if (list.size() <= 1) {