/*-
 * Copyright 2016 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package org.eclipse.january.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers for working out broadcast-compatible shapes and strides.
 * <p>
 * Two dimension lengths are treated as compatible when they are equal or when
 * either of them is one. Shapes of lower rank are aligned to the right by
 * prefixing them with ones (see {@link #padShape(int[], int)}).
 */
public final class BroadcastUtils {

    /**
     * Calculate shapes for broadcasting
     * <p>
     * Whichever of the two shapes has the lower rank is prefixed with ones so
     * that both have the same rank; every dimension is then checked for
     * compatibility (equal lengths, or either length equal to one).
     *
     * @param oldShape shape to broadcast from
     * @param size number of items held with the old shape; only consulted when
     *            {@code newShape} is zero-rank, in which case it must be one
     * @param newShape requested shape
     * @return broadcasted shape and full new shape or null if it cannot be done
     */
    public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
        if (newShape == null) {
            return null;
        }

        int brank = newShape.length;
        if (brank == 0) {
            // a zero-rank target can only hold a single item
            if (size == 1) {
                return new int[][] {oldShape, newShape};
            }
            return null;
        }

        if (Arrays.equals(oldShape, newShape)) {
            return new int[][] {oldShape, newShape};
        }

        int offset = brank - oldShape.length;
        if (offset < 0) { // when new shape is incomplete
            newShape = padShape(newShape, -offset);
            // the padded shape is longer, so the check below must use its rank;
            // otherwise only the padded ones would be compared and the real
            // trailing dimensions would never be validated
            brank = newShape.length;
            offset = 0;
        }

        final int[] bshape;
        if (offset > 0) { // new shape has extra dimensions
            bshape = padShape(oldShape, offset);
        } else {
            bshape = oldShape;
        }

        for (int i = 0; i < brank; i++) {
            if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
                return null;
            }
        }

        return new int[][] {bshape, newShape};
    }

    /**
     * Pad shape by prefixing with ones
     *
     * @param shape shape to pad
     * @param padding number of ones to put in front of the shape
     * @return new shape or old shape if padding is zero (the very same array
     *         instance is returned in that case, not a copy)
     * @throws IllegalArgumentException if padding is negative
     */
    public static int[] padShape(final int[] shape, final int padding) {
        if (padding < 0) {
            throw new IllegalArgumentException("Padding must be zero or greater");
        }

        if (padding == 0) {
            return shape;
        }

        final int[] nshape = new int[shape.length + padding];
        Arrays.fill(nshape, 1);
        System.arraycopy(shape, 0, nshape, padding, shape.length);
        return nshape;
    }

    /**
     * Take in shapes and broadcast them to same rank
     * <p>
     * Null entries are skipped and do not appear in the returned list. When
     * there are no non-null shapes, the returned list holds only an empty
     * maximum shape.
     *
     * @param shapes shapes to broadcast against each other
     * @return list of broadcasted shapes plus the first entry is the maximum shape
     * @throws IllegalArgumentException if a dimension of any shape is neither
     *             one nor equal to the maximum for that dimension
     */
    public static List<int[]> broadcastShapes(int[]... shapes) {
        // start at zero (not -1) so that a call without any usable shape does
        // not attempt to allocate a negative-length array further down
        int maxRank = 0;
        for (int[] s : shapes) {
            if (s == null) {
                continue;
            }

            int r = s.length;
            if (r > maxRank) {
                maxRank = r;
            }
        }

        List<int[]> newShapes = new ArrayList<int[]>();
        for (int[] s : shapes) {
            if (s == null) {
                continue;
            }
            newShapes.add(padShape(s, maxRank - s.length));
        }

        int[] maxShape = new int[maxRank];
        for (int i = 0; i < maxRank; i++) {
            int m = -1;
            for (int[] s : newShapes) {
                int l = s[i];
                if (l > m) {
                    // a larger length may only replace a maximum that is still <= 1
                    if (m > 1) {
                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
                    }
                    m = l;
                }
            }
            maxShape[i] = m;
        }

        checkShapes(maxShape, newShapes);
        newShapes.add(0, maxShape);
        return newShapes;
    }

    /**
     * Take in shapes and broadcast them to maximum shape
     * <p>
     * Null entries are skipped and do not appear in the returned list.
     *
     * @param maxShape shape that all other shapes must be compatible with
     * @param shapes shapes to pad up to the rank of the maximum shape
     * @return list of broadcasted shapes
     * @throws IllegalArgumentException if a shape has a higher rank than the
     *             maximum shape, or if a dimension of any shape is neither one
     *             nor equal to the maximum for that dimension
     */
    public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
        int maxRank = maxShape.length;
        for (int[] s : shapes) {
            if (s == null) {
                continue;
            }

            int r = s.length;
            if (r > maxRank) {
                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
            }
        }

        List<int[]> newShapes = new ArrayList<int[]>();
        for (int[] s : shapes) {
            if (s == null) {
                continue;
            }
            newShapes.add(padShape(s, maxRank - s.length));
        }

        checkShapes(maxShape, newShapes);
        return newShapes;
    }

    /**
     * Check that, in every dimension, each shape's length is either one or
     * equal to the corresponding length in the maximum shape
     *
     * @param maxShape maximum shape
     * @param newShapes shapes of the same rank as the maximum shape
     * @throws IllegalArgumentException if any length fails the check
     */
    private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
        for (int i = 0; i < maxShape.length; i++) {
            int m = maxShape[i];
            for (int[] s : newShapes) {
                int l = s[i];
                if (l != 1 && l != m) {
                    throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
                }
            }
        }
    }

    /**
     * Create a zero-filled dataset of given shape to hold the result of
     * combining two datasets
     * <p>
     * The dataset type is the one picked by {@code DTypeUtils.getBestDType}
     * for both inputs, except when exactly one input is zero-rank: the type of
     * the zero-rank input is then ignored unless it has floating point
     * elements. The number of elements per item is the larger of the inputs'.
     *
     * @param a first dataset
     * @param b second dataset
     * @param shape shape of dataset to create
     * @return dataset created by {@code DatasetFactory.zeros}
     */
    @SuppressWarnings("deprecation")
    static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
        final int rt;
        final int ar = a.getRank();
        final int br = b.getRank();
        final int tt = DTypeUtils.getBestDType(a.getDType(), b.getDType());
        if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point
            if (ar == 0) {
                rt = a.hasFloatingPointElements() ? tt : b.getDType();
            } else {
                rt = b.hasFloatingPointElements() ? tt : a.getDType();
            }
        } else {
            rt = tt;
        }
        final int ia = a.getElementsPerItem();
        final int ib = b.getElementsPerItem();

        return DatasetFactory.zeros(ia > ib ? ia : ib, shape, rt);
    }

    /**
     * Check that the numbers of elements per item of two input datasets, and
     * of an optional output dataset, allow them to be combined
     * <p>
     * The inputs are accepted when their numbers of elements per item are
     * equal, when either is one, or when either dataset has a size of one.
     * The output is only checked when it is not null and its type is not
     * {@code Dataset.BOOL}.
     *
     * @param a first input dataset
     * @param b second input dataset
     * @param o output dataset, can be null
     * @throws IllegalArgumentException if the inputs mismatch each other or
     *             the output mismatches the inputs
     */
    static void checkItemSize(Dataset a, Dataset b, Dataset o) {
        final int isa = a.getElementsPerItem();
        final int isb = b.getElementsPerItem();
        if (isa != isb && isa != 1 && isb != 1) {
            // exempt single-value dataset case too
            if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1)) {
                throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
            }
        }
        if (o != null && o.getDType() != Dataset.BOOL) {
            final int ism = Math.max(isa, isb);
            final int iso = o.getElementsPerItem();
            if (iso != ism && ism != 1) {
                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
            }
        }
    }

    /**
     * Create a stride array from a dataset to a broadcast shape
     * <p>
     * Uses the dataset's number of elements per item, its shape and its
     * strides.
     *
     * @param a dataset
     * @param broadcastShape shape to broadcast to, must have same rank as dataset
     * @return stride array
     * @throws IllegalArgumentException if the ranks differ
     */
    public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
        return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
    }

    /**
     * Create a stride array from a dataset to a broadcast shape
     * <p>
     * A dimension whose length differs between the original and the broadcast
     * shape gets a stride of zero. Any other dimension keeps its original
     * stride or, when no original stride is given, gets a stride computed
     * from the item size and the lengths of the dimensions that follow it.
     *
     * @param isize number of elements per item
     * @param oShape original shape
     * @param oStride original stride, can be null
     * @param broadcastShape shape to broadcast to, must have same rank as original shape
     * @return stride array
     * @throws IllegalArgumentException if the ranks differ
     */
    public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
        int rank = oShape.length;
        if (broadcastShape.length != rank) {
            throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
        }

        int[] stride = new int[rank];
        if (oStride == null) {
            // no strides supplied: accumulate them from the last dimension backwards
            int s = isize;
            for (int j = rank - 1; j >= 0; j--) {
                if (broadcastShape[j] == oShape[j]) {
                    stride[j] = s;
                    s *= oShape[j];
                } else {
                    stride[j] = 0;
                }
            }
        } else {
            for (int j = 0; j < rank; j++) {
                if (broadcastShape[j] == oShape[j]) {
                    stride[j] = oStride[j];
                } else {
                    stride[j] = 0;
                }
            }
        }

        return stride;
    }

    /**
     * Converts and broadcast all objects as datasets of same shape
     * <p>
     * Each object is turned into a dataset with
     * {@code DatasetFactory.createFromObject} and then replaced by its
     * broadcast view of the maximum shape found by
     * {@link #broadcastShapes(int[][])}.
     *
     * @param objects objects to convert
     * @return all as broadcasted to same shape
     * @throws IllegalArgumentException if the shapes cannot be broadcast together
     */
    public static Dataset[] convertAndBroadcast(Object... objects) {
        final int n = objects.length;

        Dataset[] datasets = new Dataset[n];
        int[][] shapes = new int[n][];
        for (int i = 0; i < n; i++) {
            Dataset d = DatasetFactory.createFromObject(objects[i]);
            datasets[i] = d;
            shapes[i] = d.getShapeRef();
        }

        List<int[]> nShapes = BroadcastUtils.broadcastShapes(shapes);
        int[] mshape = nShapes.get(0);
        for (int i = 0; i < n; i++) {
            datasets[i] = datasets[i].getBroadcastView(mshape);
        }

        return datasets;
    }
}