changeset 7808:f7345843a1d6

override default methods for CopyOnWriteArrayList's subLists
author akhil
date Wed, 03 Apr 2013 22:00:15 -0700
parents 68138df9be76
children 8c1a7e894916
files src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java
diffstat 3 files changed, 316 insertions(+), 37 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java	Wed Apr 03 18:19:56 2013 -0400
+++ b/src/share/classes/java/util/concurrent/CopyOnWriteArrayList.java	Wed Apr 03 22:00:15 2013 -0700
@@ -1293,6 +1293,57 @@
             }
         }
 
+        @Override
+        public void forEach(Consumer<? super E> action) {
+            @SuppressWarnings("unchecked")
+            final E[] elements = (E[]) l.getArray();
+            checkForComodification();
+            l.forEach(action, elements, offset, offset + size);
+        }
+
+        @Override
+        public void sort(Comparator<? super E> c) {
+            final ReentrantLock lock = l.lock;
+            lock.lock();
+            try {
+                checkForComodification();
+                l.sort(c, offset, offset + size);
+                expectedArray = l.getArray();
+            } finally {
+                lock.unlock();
+            }
+        }
+
+        @Override
+        public boolean removeIf(Predicate<? super E> filter) {
+            Objects.requireNonNull(filter);
+            final ReentrantLock lock = l.lock;
+            lock.lock();
+            try {
+                checkForComodification();
+                final int removeCount =
+                        l.removeIf(filter, offset, offset + size);
+                expectedArray = l.getArray();
+                size -= removeCount;
+                return removeCount > 0;
+            } finally {
+                lock.unlock();
+            }
+        }
+
+        @Override
+        public void replaceAll(UnaryOperator<E> operator) {
+            final ReentrantLock lock = l.lock;
+            lock.lock();
+            try {
+                checkForComodification();
+                l.replaceAll(operator, offset, offset + size);
+                expectedArray = l.getArray();
+            } finally {
+                lock.unlock();
+            }
+        }
+
         public Spliterator<E> spliterator() {
             int lo = offset;
             int hi = offset + size;
@@ -1377,13 +1428,18 @@
         }
     }
 
+    @Override
     @SuppressWarnings("unchecked")
-    @Override
     public void forEach(Consumer<? super E> action) {
+        forEach(action, (E[]) getArray(), 0, size());
+    }
+
+    private void forEach(Consumer<? super E> action,
+                         final E[] elements,
+                         final int from, final int to) {
         Objects.requireNonNull(action);
-        final Object[] elements = getArray();
-        for (final Object element : elements) {
-            action.accept((E) element);
+        for (int i = from; i < to; i++) {
+            action.accept(elements[i]);
         }
     }
 
@@ -1393,52 +1449,91 @@
         final ReentrantLock lock = this.lock;
         lock.lock();
         try {
-            @SuppressWarnings("unchecked")
-            final E[] elements = (E[]) getArray();
-            final E[] newElements = Arrays.copyOf(elements, elements.length);
-            Arrays.sort(newElements, c);
-            setArray(newElements);
+            sort(c, 0, size());
         } finally {
             lock.unlock();
         }
     }
 
+    // must be called with this.lock held
+    @SuppressWarnings("unchecked")
+    private void sort(Comparator<? super E> c, final int from, final int to) {
+        final E[] elements = (E[]) getArray();
+        final E[] newElements = Arrays.copyOf(elements, elements.length);
+        for (int i = 0; i < from; i++) {
+            newElements[i] = elements[i];
+        }
+        // only elements [from, to) are sorted
+        Arrays.sort(newElements, from, to, c);
+        for (int i = to; i < elements.length; i++) {
+            newElements[i] = elements[i];
+        }
+        setArray(newElements);
+    }
+
     @Override
     public boolean removeIf(Predicate<? super E> filter) {
         Objects.requireNonNull(filter);
         final ReentrantLock lock = this.lock;
         lock.lock();
         try {
+            return removeIf(filter, 0, size()) > 0;
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    // must be called with this.lock held
+    private int removeIf(Predicate<? super E> filter, final int from, final int to) {
+        Objects.requireNonNull(filter);
+        final ReentrantLock lock = this.lock;
+        lock.lock();
+        try {
             @SuppressWarnings("unchecked")
             final E[] elements = (E[]) getArray();
-            final int size = elements.length;
 
             // figure out which elements are to be removed
             // any exception thrown from the filter predicate at this stage
             // will leave the collection unmodified
             int removeCount = 0;
-            final BitSet removeSet = new BitSet(size);
-            for (int i=0; i < size; i++) {
-                final E element = elements[i];
+            final int range = to - from;
+            final BitSet removeSet = new BitSet(range);
+            for (int i = 0; i < range; i++) {
+                final E element = elements[from + i];
                 if (filter.test(element)) {
+                    // removeSet is zero-based to keep its size small
                     removeSet.set(i);
                     removeCount++;
                 }
             }
 
             // copy surviving elements into a new array
-            final boolean anyToRemove = removeCount > 0;
-            if (anyToRemove) {
+            if (removeCount > 0) {
                 final int newSize = elements.length - removeCount;
-                final Object[] newElements = new Object[newSize];
-                for (int i=0, j=0; (i < size) && (j < newSize); i++, j++) {
+                final int newRange = newSize - from;
+                @SuppressWarnings("unchecked")
+                final E[] newElements = (E[]) new Object[newSize];
+                // copy elements before [from, to) unmodified
+                for (int i = 0; i < from; i++) {
+                    newElements[i] = elements[i];
+                }
+                // elements [from, to) are subject to removal
+                int j = from;
+                for (int i = 0; (i < range) && (j < newRange); i++) {
                     i = removeSet.nextClearBit(i);
-                    newElements[j] = elements[i];
+                    if (i >= range) {
+                        break;
+                    }
+                    newElements[j++] = elements[from + i];
+                }
+                // copy any remaining elements beyond [from, to)
+                for (int i = to; (i < elements.length) && (j < newSize); i++) {
+                    newElements[j++] = elements[i];
                 }
                 setArray(newElements);
             }
 
-            return anyToRemove;
+            return removeCount;
         } finally {
             lock.unlock();
         }
@@ -1450,17 +1545,27 @@
         final ReentrantLock lock = this.lock;
         lock.lock();
         try {
-            @SuppressWarnings("unchecked")
-            final E[] elements = (E[]) getArray();
-            final int len = elements.length;
-            @SuppressWarnings("unchecked")
-            final E[] newElements = (E[]) new Object[len];
-            for (int i=0; i < len; i++) {
-                newElements[i] = operator.apply(elements[i]);
-            }
-            setArray(newElements);
+            replaceAll(operator, 0, size());
         } finally {
             lock.unlock();
         }
     }
+
+    // must be called with this.lock held
+    @SuppressWarnings("unchecked")
+    private void replaceAll(UnaryOperator<E> operator, final int from, final int to) {
+        final E[] elements = (E[]) getArray();
+        final E[] newElements = (E[]) new Object[elements.length];
+        for (int i = 0; i < from; i++) {
+            newElements[i] = elements[i];
+        }
+        // the operator is only applied to elements [from, to)
+        for (int i = from; i < to; i++) {
+            newElements[i] = operator.apply(elements[i]);
+        }
+        for (int i = to; i < elements.length; i++) {
+            newElements[i] = elements[i];
+        }
+        setArray(newElements);
+    }
 }
--- a/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java	Wed Apr 03 18:19:56 2013 -0400
+++ b/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java	Wed Apr 03 22:00:15 2013 -0700
@@ -54,9 +54,11 @@
         "java.util.TreeSet"
     };
 
+    private static final int SIZE = 100;
+
     @Test
     public void testForEach() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final Set<Integer> original = ((Set<Integer>) test.original);
             final Set<Integer> set = ((Set<Integer>) test.collection);
@@ -74,7 +76,7 @@
 
     @Test
     public void testRemoveIf() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final Set<Integer> original = ((Set<Integer>) test.original);
             final Set<Integer> set = ((Set<Integer>) test.collection);
--- a/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java	Wed Apr 03 18:19:56 2013 -0400
+++ b/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java	Wed Apr 03 22:00:15 2013 -0700
@@ -23,16 +23,19 @@
  * questions.
  */
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.Comparators;
 import java.util.List;
 import java.util.LinkedList;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 
@@ -70,9 +73,29 @@
     private static final Comparator<AtomicInteger> ATOMIC_INTEGER_COMPARATOR =
             (x, y) -> x.intValue() - y.intValue();
 
+    private static final int SIZE = 100;
+    private static final int SUBLIST_FROM = 2;
+    private static final int SUBLIST_TO = SIZE - SUBLIST_FROM;
+    private static final int SUBLIST_SIZE = SUBLIST_TO - SUBLIST_FROM;
+
+    private static interface Callback {
+        void call(List<Integer> list);
+    }
+
+    // call the callback for each recursive subList
+    private void trimmedSubList(final List<Integer> list, final Callback callback) {
+        int size = list.size();
+        if (size > 1) {
+            // trim 1 element from both ends
+            final List<Integer> subList = list.subList(1, size - 1);
+            callback.call(subList);
+            trimmedSubList(subList, callback);
+        }
+    }
+
     @Test
     public void testForEach() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
@@ -80,12 +103,41 @@
             list.forEach(actual::add);
             CollectionAsserts.assertContents(actual, list);
             CollectionAsserts.assertContents(actual, original);
+
+            if (original.size() > SUBLIST_SIZE) {
+                final List<Integer> subList = original.subList(SUBLIST_FROM, SUBLIST_TO);
+                final List<Integer> actualSubList = new LinkedList<>();
+                subList.forEach(actualSubList::add);
+                assertEquals(actualSubList.size(), SUBLIST_SIZE);
+                for (int i = 0; i < SUBLIST_SIZE; i++) {
+                    assertEquals(actualSubList.get(i), original.get(i + SUBLIST_FROM));
+                }
+            }
+
+            trimmedSubList(list, new Callback() {
+                @Override
+                public void call(final List<Integer> list) {
+                    final List<Integer> actual = new LinkedList<>();
+                    list.forEach(actual::add);
+                    CollectionAsserts.assertContents(actual, list);
+                }
+            });
         }
     }
 
     @Test
     public void testRemoveIf() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE);
+
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> original = ((List<Integer>) test.original);
+            final List<Integer> list = ((List<Integer>) test.collection);
+            final AtomicInteger offset = new AtomicInteger(1);
+            while (list.size() > 0) {
+                removeFirst(original, list, offset);
+            }
+        }
+
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
@@ -101,12 +153,66 @@
             list.removeIf(pEven);
             assertTrue(list.isEmpty());
         }
+
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> original = ((List<Integer>) test.original);
+            final List<Integer> list = ((List<Integer>) test.collection);
+            final List<Integer> listCopy = new ArrayList<>(list);
+            if (original.size() > SUBLIST_SIZE) {
+                final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO);
+                final List<Integer> subListCopy = new ArrayList<>(subList);
+                listCopy.removeAll(subList);
+                subList.removeIf(pOdd);
+                for (int i : subList) {
+                    assertTrue((i % 2) == 0);
+                }
+                for (int i : subListCopy) {
+                    if (i % 2 == 0) {
+                        assertTrue(subList.contains(i));
+                    } else {
+                        assertFalse(subList.contains(i));
+                    }
+                }
+                subList.removeIf(pEven);
+                assertTrue(subList.isEmpty());
+                // elements outside the view should remain
+                CollectionAsserts.assertContents(list, listCopy);
+            }
+        }
+
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> list = ((List<Integer>) test.collection);
+            trimmedSubList(list, new Callback() {
+                @Override
+                public void call(final List<Integer> list) {
+                    final List<Integer> copy = new ArrayList<>(list);
+                    list.removeIf(pOdd);
+                    for (int i : list) {
+                        assertTrue((i % 2) == 0);
+                    }
+                    for (int i : copy) {
+                        if (i % 2 == 0) {
+                            assertTrue(list.contains(i));
+                        } else {
+                            assertFalse(list.contains(i));
+                        }
+                    }
+                }
+            });
+        }
+    }
+
+    // remove the first element
+    private void removeFirst(final List<Integer> original, final List<Integer> list, final AtomicInteger offset) {
+        final AtomicBoolean first = new AtomicBoolean(true);
+        list.removeIf(x -> first.getAndSet(false));
+        CollectionAsserts.assertContents(original.subList(offset.getAndIncrement(), original.size()), list);
     }
 
     @Test
     public void testReplaceAll() throws Exception {
         final int scale = 3;
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
@@ -114,12 +220,47 @@
             for (int i=0; i < original.size(); i++) {
                 assertTrue(list.get(i) == (scale * original.get(i)), "mismatch at index " + i);
             }
+
+            if (original.size() > SUBLIST_SIZE) {
+                final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO);
+                subList.replaceAll(x -> x + 1);
+                // verify elements in view [from, to) were replaced
+                for (int i = 0; i < SUBLIST_SIZE; i++) {
+                    assertTrue(subList.get(i) == ((scale * original.get(i + SUBLIST_FROM)) + 1),
+                            "mismatch at sublist index " + i);
+                }
+                // verify that elements [0, from) remain unmodified
+                for (int i = 0; i < SUBLIST_FROM; i++) {
+                    assertTrue(list.get(i) == (scale * original.get(i)),
+                            "mismatch at original index " + i);
+                }
+                // verify that elements [to, size) remain unmodified
+                for (int i = SUBLIST_TO; i < list.size(); i++) {
+                    assertTrue(list.get(i) == (scale * original.get(i)),
+                            "mismatch at original index " + i);
+                }
+            }
+        }
+
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> list = ((List<Integer>) test.collection);
+            trimmedSubList(list, new Callback() {
+                @Override
+                public void call(final List<Integer> list) {
+                    final List<Integer> copy = new ArrayList<>(list);
+                    final int offset = 5;
+                    list.replaceAll(x -> offset + x);
+                    for (int i=0; i < copy.size(); i++) {
+                        assertTrue(list.get(i) == (offset + copy.get(i)), "mismatch at index " + i);
+                    }
+                }
+            });
         }
     }
 
     @Test
     public void testSort() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
@@ -173,12 +314,43 @@
             for (int i=0; i < test.original.size(); i++) {
                 assertEquals(i, incomparables.get(i).intValue());
             }
+
+            if (original.size() > SUBLIST_SIZE) {
+                final List<Integer> copy = new ArrayList<>(list);
+                final List<Integer> subList = list.subList(SUBLIST_FROM, SUBLIST_TO);
+                CollectionSupplier.shuffle(subList);
+                subList.sort(Comparators.<Integer>naturalOrder());
+                CollectionAsserts.assertSorted(subList, Comparators.<Integer>naturalOrder());
+                // verify that elements [0, from) remain unmodified
+                for (int i = 0; i < SUBLIST_FROM; i++) {
+                    assertTrue(list.get(i) == copy.get(i),
+                            "mismatch at index " + i);
+                }
+                // verify that elements [to, size) remain unmodified
+                for (int i = SUBLIST_TO; i < list.size(); i++) {
+                    assertTrue(list.get(i) == copy.get(i),
+                            "mismatch at index " + i);
+                }
+            }
+        }
+
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> list = ((List<Integer>) test.collection);
+            trimmedSubList(list, new Callback() {
+                @Override
+                public void call(final List<Integer> list) {
+                    final List<Integer> copy = new ArrayList<>(list);
+                    CollectionSupplier.shuffle(list);
+                    list.sort(Comparators.<Integer>naturalOrder());
+                    CollectionAsserts.assertSorted(list, Comparators.<Integer>naturalOrder());
+                }
+            });
         }
     }
 
     @Test
     public void testRemoveIfThrowsCME() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> list = ((List<Integer>) test.collection);
             if (list.size() <= 1) {
@@ -199,7 +371,7 @@
 
     @Test
     public void testReplaceAllThrowsCME() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> list = ((List<Integer>) test.collection);
             if (list.size() <= 1) {
@@ -220,7 +392,7 @@
 
     @Test
     public void testSortThrowsCME() throws Exception {
-        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, 100);
+        final CollectionSupplier supplier = new CollectionSupplier(LIST_CME_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> list = ((List<Integer>) test.collection);
             if (list.size() <= 1) {