001/*- 002 * Copyright 2016 Diamond Light Source Ltd. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 */ 009 010package org.eclipse.january.dataset; 011 012import java.util.ArrayList; 013import java.util.Arrays; 014import java.util.List; 015 016public final class BroadcastUtils { 017 018 /** 019 * Calculate shapes for broadcasting 020 * @param oldShape 021 * @param size 022 * @param newShape 023 * @return broadcasted shape and full new shape or null if it cannot be done 024 */ 025 public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) { 026 if (newShape == null) 027 return null; 028 029 int brank = newShape.length; 030 if (brank == 0) { 031 if (size == 1) 032 return new int[][] {oldShape, newShape}; 033 return null; 034 } 035 036 if (Arrays.equals(oldShape, newShape)) 037 return new int[][] {oldShape, newShape}; 038 039 int offset = brank - oldShape.length; 040 if (offset < 0) { // when new shape is incomplete 041 newShape = padShape(newShape, -offset); 042 offset = 0; 043 } 044 045 int[] bshape; 046 if (offset > 0) { // new shape has extra dimensions 047 bshape = padShape(oldShape, offset); 048 } else { 049 bshape = oldShape; 050 } 051 052 for (int i = 0; i < brank; i++) { 053 if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) { 054 return null; 055 } 056 } 057 058 return new int[][] {bshape, newShape}; 059 } 060 061 /** 062 * Pad shape by prefixing with ones 063 * @param shape 064 * @param padding 065 * @return new shape or old shape if padding is zero 066 */ 067 public static int[] padShape(final int[] shape, final int padding) { 068 if (padding < 0) 069 throw new IllegalArgumentException("Padding must be zero or greater"); 070 071 if (padding == 0) 072 return shape; 073 074 final int[] nshape = new int[shape.length + padding]; 075 Arrays.fill(nshape, 1); 076 System.arraycopy(shape, 0, nshape, padding, shape.length); 077 return nshape; 078 } 079 080 /** 081 * Take in shapes and broadcast them to same rank 082 * @param shapes 083 * @return list of broadcasted shapes plus the first entry is the maximum shape 084 */ 085 public static List<int[]> broadcastShapes(int[]... shapes) { 086 int maxRank = -1; 087 for (int[] s : shapes) { 088 if (s == null) 089 continue; 090 091 int r = s.length; 092 if (r > maxRank) { 093 maxRank = r; 094 } 095 } 096 097 List<int[]> newShapes = new ArrayList<int[]>(); 098 for (int[] s : shapes) { 099 if (s == null) 100 continue; 101 newShapes.add(padShape(s, maxRank - s.length)); 102 } 103 104 int[] maxShape = new int[maxRank]; 105 for (int i = 0; i < maxRank; i++) { 106 int m = -1; 107 for (int[] s : newShapes) { 108 int l = s[i]; 109 if (l > m) { 110 if (m > 1) { 111 throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum"); 112 } 113 m = l; 114 } 115 } 116 maxShape[i] = m; 117 } 118 119 checkShapes(maxShape, newShapes); 120 newShapes.add(0, maxShape); 121 return newShapes; 122 } 123 124 /** 125 * Take in shapes and broadcast them to maximum shape 126 * @param maxShape 127 * @param shapes 128 * @return list of broadcasted shapes 129 */ 130 public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) { 131 int maxRank = maxShape.length; 132 for (int[] s : shapes) { 133 if (s == null) 134 continue; 135 136 int r = s.length; 137 if (r > maxRank) { 138 throw new IllegalArgumentException("A shape exceeds given rank of maximum shape"); 139 } 140 } 141 142 List<int[]> newShapes = new ArrayList<int[]>(); 143 for (int[] s : shapes) { 144 if (s == null) 145 continue; 146 newShapes.add(padShape(s, maxRank - s.length)); 147 } 148 149 checkShapes(maxShape, newShapes); 150 return newShapes; 151 } 152 153 private static void checkShapes(int[] maxShape, List<int[]> newShapes) { 154 for (int i = 0; i < maxShape.length; i++) { 155 int m = maxShape[i]; 156 for (int[] s : newShapes) { 157 int l = s[i]; 158 if (l != 1 && l != m) { 159 throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum"); 160 } 161 } 162 } 163 } 164 165 @SuppressWarnings("deprecation") 166 static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) { 167 final int rt; 168 final int ar = a.getRank(); 169 final int br = b.getRank(); 170 final int tt = DTypeUtils.getBestDType(a.getDType(), b.getDType()); 171 if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point 172 if (ar == 0) { 173 rt = a.hasFloatingPointElements() ? tt : b.getDType(); 174 } else { 175 rt = b.hasFloatingPointElements() ? tt : a.getDType(); 176 } 177 } else { 178 rt = tt; 179 } 180 final int ia = a.getElementsPerItem(); 181 final int ib = b.getElementsPerItem(); 182 183 return DatasetFactory.zeros(ia > ib ? ia : ib, shape, rt); 184 } 185 186 /** 187 * Check if dataset item sizes are compatible 188 * <p> 189 * Dataset a is considered compatible with the output dataset if any of the 190 * conditions are true: 191 * <ul> 192 * <li>o is undefined</li> 193 * <li>a has item size equal to o's</li> 194 * <li>a has item size equal to 1</li> 195 * <li>o has item size equal to 1</li> 196 * </ul> 197 * @param a input dataset a 198 * @param o output dataset (can be null) 199 */ 200 static void checkItemSize(Dataset a, Dataset o) { 201 final int isa = a.getElementsPerItem(); 202 if (o != null) { 203 final int iso = o.getElementsPerItem(); 204 if (isa != iso && isa != 1 && iso != 1) { 205 throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'"); 206 } 207 } 208 } 209 210 /** 211 * Check if dataset item sizes are compatible 212 * <p> 213 * Dataset a is considered compatible with the output dataset if any of the 214 * conditions are true: 215 * <ul> 216 * <li>a has item size equal to b's</li> 217 * <li>a has item size equal to 1</li> 218 * <li>b has item size equal to 1</li> 219 * <li>a or b are single-valued</li> 220 * </ul> 221 * and, o is undefined, or any of the following are true: 222 * <ul> 223 * <li>o has item size equal to maximum of a and b's</li> 224 * <li>o has item size equal to 1</li> 225 * <li>a and b have item sizes of 1</li> 226 * </ul> 227 * @param a input dataset a 228 * @param b input dataset b 229 * @param o output dataset 230 */ 231 static void checkItemSize(Dataset a, Dataset b, Dataset o) { 232 final int isa = a.getElementsPerItem(); 233 final int isb = b.getElementsPerItem(); 234 if (isa != isb && isa != 1 && isb != 1) { 235 // exempt single-value dataset case too 236 if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) { 237 throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another"); 238 } 239 } 240 if (o != null && o.getDType() != Dataset.BOOL) { 241 final int ism = Math.max(isa, isb); 242 final int iso = o.getElementsPerItem(); 243 if (iso != ism && iso != 1 && ism != 1) { 244 throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'"); 245 } 246 } 247 } 248 249 /** 250 * Create a stride array from a dataset to a broadcast shape 251 * @param a dataset 252 * @param broadcastShape 253 * @return stride array 254 */ 255 public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) { 256 return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape); 257 } 258 259 /** 260 * Create a stride array from a dataset to a broadcast shape 261 * @param isize 262 * @param oShape original shape 263 * @param oStride original stride 264 * @param broadcastShape 265 * @return stride array 266 */ 267 public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) { 268 int rank = oShape.length; 269 if (broadcastShape.length != rank) { 270 throw new IllegalArgumentException("Dataset must have same rank as broadcast shape"); 271 } 272 273 int[] stride = new int[rank]; 274 if (oStride == null) { 275 int s = isize; 276 for (int j = rank - 1; j >= 0; j--) { 277 if (broadcastShape[j] == oShape[j]) { 278 stride[j] = s; 279 s *= oShape[j]; 280 } else { 281 stride[j] = 0; 282 } 283 } 284 } else { 285 for (int j = 0; j < rank; j++) { 286 if (broadcastShape[j] == oShape[j]) { 287 stride[j] = oStride[j]; 288 } else { 289 stride[j] = 0; 290 } 291 } 292 } 293 294 return stride; 295 } 296 297 /** 298 * Converts and broadcast all objects as datasets of same shape 299 * @param objects 300 * @return all as broadcasted to same shape 301 */ 302 public static Dataset[] convertAndBroadcast(Object... objects) { 303 final int n = objects.length; 304 305 Dataset[] datasets = new Dataset[n]; 306 int[][] shapes = new int[n][]; 307 for (int i = 0; i < n; i++) { 308 Dataset d = DatasetFactory.createFromObject(objects[i]); 309 datasets[i] = d; 310 shapes[i] = d.getShapeRef(); 311 } 312 313 List<int[]> nShapes = BroadcastUtils.broadcastShapes(shapes); 314 int[] mshape = nShapes.get(0); 315 for (int i = 0; i < n; i++) { 316 datasets[i] = datasets[i].getBroadcastView(mshape); 317 } 318 319 return datasets; 320 } 321}