Line | Hits | Source |
---|---|---|
1 | /* | |
2 | * Copyright (c) 2003, the JUNG Project and the Regents of the University | |
3 | * of California | |
4 | * All rights reserved. | |
5 | * | |
6 | * This software is open-source under the BSD license; see either | |
7 | * "license.txt" or | |
8 | * http://jung.sourceforge.net/license.txt for a description. | |
9 | */ | |
10 | /* | |
11 | * Created on Aug 9, 2004 | |
12 | * | |
13 | */ | |
14 | package edu.uci.ics.jung.algorithms.cluster; | |
15 | ||
16 | import java.util.Arrays; | |
17 | import java.util.Collection; | |
18 | import java.util.HashMap; | |
19 | import java.util.HashSet; | |
20 | import java.util.Iterator; | |
21 | import java.util.Map; | |
22 | import java.util.Set; | |
23 | ||
24 | import cern.jet.random.engine.DRand; | |
25 | import cern.jet.random.engine.RandomEngine; | |
26 | import edu.uci.ics.jung.statistics.DiscreteDistribution; | |
27 | ||
28 | ||
29 | ||
30 | /** | |
31 | * Groups Objects into a specified number of clusters, based on their | |
32 | * proximity in d-dimensional space, using the k-means algorithm. | |
33 | * | |
34 | * @author Joshua O'Madadhain | |
35 | */ | |
36 | public class KMeansClusterer | |
37 | { | |
38 | protected int max_iterations; | |
39 | protected double convergence_threshold; | |
40 | 1 | protected RandomEngine rand = new DRand(); |
41 | ||
42 | /** | |
43 | * Creates an instance for which calls to <code>cluster</code> will terminate | |
44 | * when either of the two following conditions is true: | |
45 | * <ul> | |
46 | * <li/>the number of iterations is > <code>max_iterations</code> | |
47 | * <li/>none of the centroids has moved as much as <code>convergence_threshold</code> | |
48 | * since the previous iteration | |
49 | * </ul> | |
50 | * @param max_iterations | |
51 | * @param convergence_threshold | |
52 | */ | |
53 | public KMeansClusterer(int max_iterations, double convergence_threshold) | |
54 | 1 | { |
55 | 1 | if (max_iterations < 0) |
56 | 0 | throw new IllegalArgumentException("max iterations must be >= 0"); |
57 | ||
58 | 1 | if (convergence_threshold <= 0) |
59 | 0 | throw new IllegalArgumentException("convergence threshold " + |
60 | "must be > 0"); | |
61 | ||
62 | 1 | this.max_iterations = max_iterations; |
63 | 1 | this.convergence_threshold = convergence_threshold; |
64 | 1 | } |
65 | ||
66 | /** | |
67 | * Returns a <code>Collection</code> of clusters, where each cluster is | |
68 | * represented as a <code>Map</code> of <code>Objects</code> to locations | |
69 | * in d-dimensional space. | |
70 | * @param object_locations a map of the Objects to cluster, to | |
71 | * <code>double</code> arrays that specify their locations in d-dimensional space. | |
72 | * @param num_clusters the number of clusters to create | |
73 | * @throws NotEnoughClustersException | |
74 | */ | |
75 | public Collection cluster(Map object_locations, int num_clusters) | |
76 | { | |
77 | 3 | if (num_clusters < 2 || num_clusters > object_locations.size()) |
78 | 1 | throw new IllegalArgumentException("number of clusters " + |
79 | "must be >= 2 and <= number of objects (" + | |
80 | object_locations.size() + ")"); | |
81 | ||
82 | 2 | if (object_locations == null || object_locations.isEmpty()) |
83 | 0 | throw new IllegalArgumentException("'objects' must be non-empty"); |
84 | ||
85 | 2 | Set centroids = new HashSet(); |
86 | 2 | Object[] obj_array = object_locations.keySet().toArray(); |
87 | 2 | Set tried = new HashSet(); |
88 | ||
89 | // create the specified number of clusters | |
90 | 12 | while (centroids.size() < num_clusters && tried.size() < object_locations.size()) |
91 | { | |
92 | 10 | Object o = obj_array[(int)(rand.nextDouble() * obj_array.length)]; |
93 | 10 | tried.add(o); |
94 | 10 | double[] mean_value = (double[])object_locations.get(o); |
95 | 10 | boolean duplicate = false; |
96 | 10 | for (Iterator iter = centroids.iterator(); iter.hasNext(); ) |
97 | { | |
98 | 8 | double[] cur = (double[])iter.next(); |
99 | 8 | if (Arrays.equals(mean_value, cur)) |
100 | 6 | duplicate = true; |
101 | } | |
102 | 10 | if (!duplicate) |
103 | 4 | centroids.add(mean_value); |
104 | } | |
105 | ||
106 | 2 | if (tried.size() >= object_locations.size()) |
107 | 1 | throw new NotEnoughClustersException(); |
108 | ||
109 | // put items in their initial clusters | |
110 | 1 | Map clusterMap = assignToClusters(object_locations, centroids); |
111 | ||
112 | // keep reconstituting clusters until either | |
113 | // (a) membership is stable, or | |
114 | // (b) number of iterations passes max_iterations, or | |
115 | // (c) max movement of any centroid is <= convergence_threshold | |
116 | 1 | int iterations = 0; |
117 | 1 | double max_movement = Double.POSITIVE_INFINITY; |
118 | 3 | while (iterations++ < max_iterations && max_movement > convergence_threshold) |
119 | { | |
120 | 2 | max_movement = 0; |
121 | 2 | Set new_centroids = new HashSet(); |
122 | // calculate new mean for each cluster | |
123 | 2 | for (Iterator iter = clusterMap.keySet().iterator(); iter.hasNext(); ) |
124 | { | |
125 | 4 | double[] centroid = (double[])iter.next(); |
126 | 4 | Map elements = (Map)clusterMap.get(centroid); |
127 | 4 | double[][] locations = new double[elements.size()][]; |
128 | 4 | int i = 0; |
129 | 4 | for (Iterator e_iter = elements.keySet().iterator(); e_iter.hasNext(); ) |
130 | 10 | locations[i++] = (double[])object_locations.get(e_iter.next()); |
131 | ||
132 | 4 | double[] mean = DiscreteDistribution.mean(locations); |
133 | 4 | max_movement = Math.max(max_movement, |
134 | Math.sqrt(DiscreteDistribution.squaredError(centroid, mean))); | |
135 | 4 | new_centroids.add(mean); |
136 | } | |
137 | ||
138 | // TODO: check membership of clusters: have they changed? | |
139 | ||
140 | // regenerate cluster membership based on means | |
141 | 2 | clusterMap = assignToClusters(object_locations, new_centroids); |
142 | } | |
143 | 1 | return (Collection)clusterMap.values(); |
144 | } | |
145 | ||
146 | /** | |
147 | * Assigns each object to the cluster whose centroid is closest to the | |
148 | * object. | |
149 | * @param object_locations a map of objects to locations | |
150 | * @param centroids the centroids of the clusters to be formed | |
151 | * @return a map of objects to assigned clusters | |
152 | */ | |
153 | protected Map assignToClusters(Map object_locations, Set centroids) | |
154 | { | |
155 | 3 | Map clusterMap = new HashMap(); |
156 | 3 | for (Iterator c_iter = centroids.iterator(); c_iter.hasNext(); ) |
157 | 6 | clusterMap.put(c_iter.next(), new HashMap()); |
158 | ||
159 | 3 | for (Iterator o_iter = object_locations.keySet().iterator(); o_iter.hasNext(); ) |
160 | { | |
161 | 15 | Object o = o_iter.next(); |
162 | 15 | double[] location = (double[])object_locations.get(o); |
163 | ||
164 | // find the cluster with the closest centroid | |
165 | 15 | Iterator c_iter = centroids.iterator(); |
166 | 15 | double[] closest = (double[])c_iter.next(); |
167 | 15 | double distance = DiscreteDistribution.squaredError(location, closest); |
168 | ||
169 | 30 | while (c_iter.hasNext()) |
170 | { | |
171 | 15 | double[] centroid = (double[])c_iter.next(); |
172 | 15 | double dist_cur = DiscreteDistribution.squaredError(location, centroid); |
173 | 15 | if (dist_cur < distance) |
174 | { | |
175 | 8 | distance = dist_cur; |
176 | 8 | closest = centroid; |
177 | } | |
178 | } | |
179 | 15 | Map elements = (Map)clusterMap.get(closest); |
180 | 15 | elements.put(o, location); |
181 | } | |
182 | ||
183 | 3 | return clusterMap; |
184 | } | |
185 | ||
186 | public void setSeed(int random_seed) | |
187 | { | |
188 | 0 | this.rand = new DRand(random_seed); |
189 | 0 | } |
190 | ||
191 | /** | |
192 | * An exception that indicates that the specified data points cannot be | |
193 | * clustered into the number of clusters requested by the user. | |
194 | * This will happen if and only if there are fewer distinct points than | |
195 | * requested clusters. (If there are fewer total data points than | |
196 | * requested clusters, <code>IllegalArgumentException</code> will be thrown.) | |
197 | * | |
198 | * @author Joshua O'Madadhain | |
199 | */ | |
200 | public static class NotEnoughClustersException extends RuntimeException | |
201 | { | |
202 | public String getMessage() | |
203 | { | |
204 | return "Not enough distinct points in the input data set to form " + | |
205 | "the requested number of clusters"; | |
206 | } | |
207 | } | |
208 | } |
this report was generated by version 1.0.5 of jcoverage. |
copyright © 2003, jcoverage ltd. all rights reserved. |