余弦相似性算法

余弦相似性算法的具体介绍参考:http://www.ruanyifeng.com/blog/2013/03/cosine_similarity.html

下面是我根据上面的介绍,用 Java 语言编写的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138

import java.io.IOException;
import java.io.StringReader;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.wltea.analyzer.lucene.IKAnalyzer;

import com.wjb.util.common.WjbTuple2;

/**
 * Static helpers for estimating how similar two texts are by taking the cosine
 * of the angle between their term-frequency vectors.
 *
 * <p>Typical flow: {@link #makeTermFrequency(String)} for each text, optionally
 * {@link #filterByKeyLength(Map, int)} and/or {@link #sort(Map)} to prune terms,
 * then {@link #makeVector(Map, Map)} followed by {@link #cosine(WjbTuple2)}.
 */
public class CosineTextSimilarity {

    /**
     * Tokenizes {@code text} with an {@code IKAnalyzer} and counts how many times
     * each emitted term occurs.
     *
     * @param text the text to tokenize
     * @return a new mutable map from term to its occurrence count
     * @throws IOException if the token stream fails while reading
     */
    public static Map<String, Integer> makeTermFrequency(String text) throws IOException
    {
        Map<String, Integer> tf = new HashMap<String, Integer>();
        // NOTE(review): the 'true' flag presumably selects IK's "smart" segmentation mode -- confirm.
        Analyzer analyzer = new IKAnalyzer(true);
        try {
            StringReader reader = new StringReader(text);
            try {
                TokenStream ts = analyzer.tokenStream("", reader);
                try {
                    CharTermAttribute term = ts.getAttribute(CharTermAttribute.class);
                    // The TokenStream contract is reset() -> incrementToken()* -> end() -> close();
                    // newer Lucene versions reject incrementToken() on a stream that was not reset.
                    ts.reset();
                    while (ts.incrementToken()) {
                        String t = term.toString();
                        Integer count = tf.get(t);
                        tf.put(t, count == null ? 1 : count + 1);
                    }
                    ts.end();
                } finally {
                    ts.close();
                }
            } finally {
                reader.close();
            }
        } finally {
            // Released even when tokenization throws.
            analyzer.close();
        }
        return tf;
    }

    /**
     * Filters a term map by key length: an entry is kept only when its key,
     * after {@code trim()}, is at least {@code length} characters long.
     * Entries with a {@code null} key have no length and are dropped.
     *
     * @param map    the term-frequency map to filter; not modified
     * @param length the minimum trimmed key length to keep
     * @return a new map holding only the entries that pass the filter
     * @throws IOException never thrown by this implementation; declared for
     *                     compatibility with existing callers
     */
    public static Map<String, Integer> filterByKeyLength(Map<String, Integer> map , int length) throws IOException
    {
        Map<String, Integer> filtered = new HashMap<String, Integer>();
        for (Entry<String, Integer> entry : map.entrySet()) {
            String key = entry.getKey();
            if (key != null && key.trim().length() >= length) {
                filtered.put(key, entry.getValue());
            }
        }
        return filtered;
    }

    /**
     * Builds two aligned count vectors over the union of the keys of both maps.
     * Index {@code i} of both arrays refers to the same term; a term missing
     * from a map contributes {@code 0} at that position. The position order is
     * the iteration order of a {@code HashSet} and therefore unspecified.
     *
     * @param first  term counts of the first text
     * @param second term counts of the second text
     * @return a tuple whose {@code _1} is the vector for {@code first} and whose
     *         {@code _2} is the vector for {@code second}
     */
    public static WjbTuple2<int[], int[]> makeVector(Map<String, Integer> first,Map<String, Integer> second){
        Set<String> keys = new HashSet<String>();
        keys.addAll(first.keySet());
        keys.addAll(second.keySet());

        int[] vector1 = new int[keys.size()];
        int[] vector2 = new int[keys.size()];
        int i = 0;
        for (String key : keys) {
            Integer count1 = first.get(key);
            if (count1 != null) {
                vector1[i] = count1;
            }
            Integer count2 = second.get(key);
            if (count2 != null) {
                vector2[i] = count2;
            }
            i++;
        }
        return new WjbTuple2<int[], int[]>(vector1, vector2);
    }

    /**
     * Computes the cosine of the angle between the two vectors in {@code tuple}:
     * {@code dot(v1, v2) / sqrt(|v1|^2 * |v2|^2)}.
     *
     * @param tuple the two vectors, as produced by {@link #makeVector(Map, Map)};
     *              {@code _2} must be at least as long as {@code _1}
     * @return the cosine similarity, or {@code 0.0} when either vector has zero
     *         magnitude (the cosine is undefined there and would otherwise be NaN)
     */
    public static double cosine(WjbTuple2<int[], int[]> tuple)
    {
        int[] vector1 = tuple._1;
        int[] vector2 = tuple._2;

        double dot = 0;
        double squaredNorm1 = 0;
        double squaredNorm2 = 0;

        for (int i = 0; i < vector1.length; i++) {
            // Widen to double BEFORE multiplying so large counts cannot overflow int.
            double a = vector1[i];
            double b = vector2[i];
            dot += a * b;
            squaredNorm1 += a * a;
            squaredNorm2 += b * b;
        }

        if (squaredNorm1 == 0 || squaredNorm2 == 0) {
            return 0.0;
        }
        return dot / Math.sqrt(squaredNorm1 * squaredNorm2);
    }

    /**
     * Returns the entries of {@code unsortMap} as a list in descending order.
     *
     * <p>When both compared values are {@code Integer}s, entries are ordered by
     * value descending, with ties broken by the key's string form descending.
     * Otherwise entries are ordered by the value's string form descending.
     *
     * @param unsortMap the map whose entries are to be ordered; not modified
     * @return a new list of the map's entries in the order described above
     */
    // Raw types are part of the published signature and are kept for caller compatibility.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static List<Entry> sort(Map unsortMap) {

        // Copy the entries so the sort does not touch the map itself.
        List<Map.Entry> list = new LinkedList<Map.Entry>(unsortMap.entrySet());

        Collections.sort(list, new Comparator<Map.Entry>() {
            @Override
            public int compare(Map.Entry o1, Map.Entry o2) {
                Object v1 = o1.getValue();
                Object v2 = o2.getValue();
                if (v1 instanceof Integer && v2 instanceof Integer) {
                    // compareTo instead of subtraction: a difference of two ints can overflow.
                    int byCount = ((Integer) v2).compareTo((Integer) v1);
                    if (byCount != 0) {
                        return byCount;
                    }
                    return o2.getKey().toString().compareTo(o1.getKey().toString());
                }
                return v2.toString().compareTo(v1.toString());
            }
        });

        return list;
    }
}

下面是用于测试的 main 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.wjb.util.common.WjbFileUtil;
import com.wjb.util.common.WjbTuple2;

/**
 * Command-line driver that compares two text files and prints their cosine
 * similarity, computed over the 20 highest-ranked terms of each file.
 */
public class Main {

    /** Maximum number of top-ranked terms kept per text before vectorizing. */
    private static final int TOP_TERM_LIMIT = 20;

    /**
     * Loads the two input files, ranks their terms, keeps the leading
     * {@link #TOP_TERM_LIMIT} terms of each, and prints the ranked term lists,
     * the elapsed milliseconds and the resulting cosine value.
     *
     * @param args unused
     * @throws Exception if loading or tokenizing either text fails
     */
    public static void main(String[] args) throws Exception {

        // NOTE(review): the first file is loaded without an explicit charset while the
        // second passes WjbFileUtil.GBK -- presumably intentional; confirm against the inputs.
        String text1 = WjbFileUtil.fromFile("d:/1.txt");
        String text2 = WjbFileUtil.fromFile("d:/2.txt" , WjbFileUtil.GBK);

        System.out.println(text2);
        long startMillis = System.currentTimeMillis();

        Map<String, Integer> freq1 = CosineTextSimilarity.makeTermFrequency(text1);
        Map<String, Integer> freq2 = CosineTextSimilarity.makeTermFrequency(text2);

        // Optional pre-filter, currently disabled:
        // freq1 = CosineTextSimilarity.filterByKeyLength(freq1, 2);
        // freq2 = CosineTextSimilarity.filterByKeyLength(freq2, 2);

        List<Entry> ranked1 = CosineTextSimilarity.sort(freq1);
        System.out.println(ranked1);
        List<Entry> top1 = firstTerms(ranked1);

        List<Entry> ranked2 = CosineTextSimilarity.sort(freq2);
        System.out.println(ranked2);
        List<Entry> top2 = firstTerms(ranked2);

        WjbTuple2<int[], int[]> vectors =
                CosineTextSimilarity.makeVector(list2Map(top1), list2Map(top2));
        double similarity = CosineTextSimilarity.cosine(vectors);

        long elapsedMillis = System.currentTimeMillis() - startMillis;
        System.out.println(elapsedMillis);

        System.out.println(similarity);
    }

    /** Returns a view of at most the first {@link #TOP_TERM_LIMIT} entries of {@code ranked}. */
    private static List<Entry> firstTerms(List<Entry> ranked) {
        return ranked.subList(0, Math.min(ranked.size(), TOP_TERM_LIMIT));
    }

    /**
     * Converts a list of entries back into a term-to-count map. Each key is
     * stored via its {@code toString()} form and each value is cast to
     * {@code Integer}.
     *
     * @param list the entries to copy
     * @return a new map containing one mapping per entry
     */
    public static Map<String, Integer> list2Map(List<Entry> list)
    {
        Map<String, Integer> result = new HashMap<String, Integer>();
        for (int idx = 0; idx < list.size(); idx++) {
            Entry entry = list.get(idx);
            String term = entry.getKey().toString();
            Integer count = (Integer) entry.getValue();
            result.put(term, count);
        }
        return result;
    }
}